diff --git a/inference/core/workflows/core_steps/formatters/vlm_as_detector/v2.py b/inference/core/workflows/core_steps/formatters/vlm_as_detector/v2.py
index 23fad45a8a..2697d7dae5 100644
--- a/inference/core/workflows/core_steps/formatters/vlm_as_detector/v2.py
+++ b/inference/core/workflows/core_steps/formatters/vlm_as_detector/v2.py
@@ -125,9 +125,11 @@ class BlockManifest(WorkflowBlockManifest):
             }
         },
     )
-    model_type: Literal["google-gemini", "anthropic-claude", "florence-2"] = Field(
-        description="Type of the model that generated prediction",
-        examples=[["google-gemini", "anthropic-claude", "florence-2"]],
+    model_type: Literal["openai", "google-gemini", "anthropic-claude", "florence-2"] = (
+        Field(
+            description="Type of the model that generated prediction",
+            examples=[["openai", "google-gemini", "anthropic-claude", "florence-2"]],
+        )
     )
     task_type: Literal[tuple(SUPPORTED_TASKS)] = Field(
         description="Task type to performed by model.",
@@ -234,7 +236,7 @@ def try_parse_json(content: str) -> Tuple[bool, dict]:
         return True, {}
 
 
-def parse_gemini_object_detection_response(
+def parse_llm_object_detection_response(
     image: WorkflowImageData,
     parsed_data: dict,
     classes: List[str],
@@ -353,8 +355,11 @@ def get_4digit_from_md5(input_string):
 
 
 REGISTERED_PARSERS = {
-    ("google-gemini", "object-detection"): parse_gemini_object_detection_response,
-    ("anthropic-claude", "object-detection"): parse_gemini_object_detection_response,
+    # LLMs
+    ("openai", "object-detection"): parse_llm_object_detection_response,
+    ("google-gemini", "object-detection"): parse_llm_object_detection_response,
+    ("anthropic-claude", "object-detection"): parse_llm_object_detection_response,
+    # Florence 2
     ("florence-2", "object-detection"): partial(
         parse_florence2_object_detection_response, florence_task_type=""
     ),
diff --git a/inference/core/workflows/core_steps/loader.py b/inference/core/workflows/core_steps/loader.py
index f877069c4a..9700755b71 100644
--- a/inference/core/workflows/core_steps/loader.py
+++ b/inference/core/workflows/core_steps/loader.py
@@ -221,6 +221,9 @@
 from inference.core.workflows.core_steps.models.foundation.openai.v3 import (
     OpenAIBlockV3,
 )
+from inference.core.workflows.core_steps.models.foundation.openai.v4 import (
+    OpenAIBlockV4,
+)
 from inference.core.workflows.core_steps.models.foundation.perception_encoder.v1 import (
     PerceptionEncoderModelBlockV1,
 )
@@ -642,6 +645,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
         OpenAIBlockV1,
         OpenAIBlockV2,
         OpenAIBlockV3,
+        OpenAIBlockV4,
         PathDeviationAnalyticsBlockV1,
         PathDeviationAnalyticsBlockV2,
         PixelateVisualizationBlockV1,
diff --git a/inference/core/workflows/core_steps/models/foundation/openai/v4.py b/inference/core/workflows/core_steps/models/foundation/openai/v4.py
new file mode 100644
index 0000000000..b0cd3000a6
--- /dev/null
+++ b/inference/core/workflows/core_steps/models/foundation/openai/v4.py
@@ -0,0 +1,935 @@
+import base64
+import json
+from functools import partial
+from typing import Any, Dict, List, Literal, Optional, Type, Union
+
+import requests
+from openai import OpenAI
+from pydantic import ConfigDict, Field, model_validator
+
+from inference.core.env import WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS
+from inference.core.managers.base import ModelManager
+from inference.core.roboflow_api import post_to_roboflow_api
+from inference.core.utils.image_utils import encode_image_to_jpeg_bytes, load_image
+from inference.core.workflows.core_steps.common.utils import run_in_parallel
+from inference.core.workflows.core_steps.common.vlms import VLM_TASKS_METADATA
+from inference.core.workflows.execution_engine.entities.base import (
+    Batch,
+    OutputDefinition,
+    WorkflowImageData,
+)
+from inference.core.workflows.execution_engine.entities.types import (
+    FLOAT_KIND,
+    IMAGE_KIND,
+    LANGUAGE_MODEL_OUTPUT_KIND,
+    LIST_OF_VALUES_KIND,
+    ROBOFLOW_MANAGED_KEY,
+    SECRET_KIND,
+    STRING_KIND,
+    ImageInputField,
+    Selector,
+)
+from inference.core.workflows.prototypes.block import (
+    BlockResult,
+    WorkflowBlock,
+    WorkflowBlockManifest,
+)
+
+OPENAI_MODELS = [
+    {
+        "id": "gpt-5.1",
+        "name": "GPT-5.1",
+        "reasoning_effort_values": ["none", "low", "medium", "high"],
+    },
+    {
+        "id": "gpt-5",
+        "name": "GPT-5",
+        "reasoning_effort_values": ["minimal", "low", "medium", "high"],
+    },
+    {
+        "id": "gpt-5-mini",
+        "name": "GPT-5 mini",
+        "reasoning_effort_values": ["minimal", "low", "medium", "high"],
+    },
+    {
+        "id": "gpt-5-nano",
+        "name": "GPT-5 nano",
+        "reasoning_effort_values": ["minimal", "low", "medium", "high"],
+    },
+    {
+        "id": "gpt-4.1",
+        "name": "GPT-4.1",
+        "reasoning_effort_values": [],
+    },
+    {
+        "id": "gpt-4.1-mini",
+        "name": "GPT-4.1 mini",
+        "reasoning_effort_values": [],
+    },
+    {
+        "id": "gpt-4.1-nano",
+        "name": "GPT-4.1 nano",
+        "reasoning_effort_values": [],
+    },
+    {
+        "id": "gpt-4o",
+        "name": "GPT-4o",
+        "reasoning_effort_values": [],
+    },
+    {
+        "id": "gpt-4o-mini",
+        "name": "GPT-4o mini",
+        "reasoning_effort_values": [],
+    },
+]
+
+MODEL_VERSION_IDS = [model["id"] for model in OPENAI_MODELS]
+
+MODEL_VERSION_METADATA = {
+    model["id"]: {"name": model["name"]} for model in OPENAI_MODELS
+}
+
+MODELS_SUPPORTING_REASONING_EFFORT = [
+    model["id"] for model in OPENAI_MODELS if model["reasoning_effort_values"]
+]
+
+MODELS_NOT_SUPPORTING_REASONING_EFFORT = [
+    model["id"] for model in OPENAI_MODELS if not model["reasoning_effort_values"]
+]
+
+MODEL_REASONING_EFFORT_VALUES = {
+    model["id"]: model["reasoning_effort_values"] for model in OPENAI_MODELS
+}
+
+SUPPORTED_TASK_TYPES_LIST = [
+    "unconstrained",
+    "ocr",
+    "structured-answering",
+    "classification",
+    "multi-label-classification",
+    "visual-question-answering",
+    "caption",
+    "detailed-caption",
+    "object-detection",
+]
+SUPPORTED_TASK_TYPES = set(SUPPORTED_TASK_TYPES_LIST)
+
+RELEVANT_TASKS_METADATA = {
+    k: v for k, v in VLM_TASKS_METADATA.items() if k in SUPPORTED_TASK_TYPES
+}
+RELEVANT_TASKS_DOCS_DESCRIPTION = "\n\n".join(
+    f"* **{v['name']}** (`{k}`) - {v['description']}"
+    for k, v in RELEVANT_TASKS_METADATA.items()
+)
+
+LONG_DESCRIPTION = f"""
+Ask a question to OpenAI's GPT models with vision capabilities (including GPT-5 and GPT-4o).
+
+You can specify arbitrary text prompts or use predefined ones; the block supports the following prompt types:
+
+{RELEVANT_TASKS_DOCS_DESCRIPTION}
+
+Provide your OpenAI API key or set the value to ``rf_key:account`` (or
+``rf_key:user:<user_id>``) to proxy requests through Roboflow's API.
+""" + +TaskType = Literal[tuple(SUPPORTED_TASK_TYPES_LIST)] + +TASKS_REQUIRING_PROMPT = { + "unconstrained", + "visual-question-answering", +} + +TASKS_REQUIRING_CLASSES = { + "classification", + "multi-label-classification", + "object-detection", +} + +TASKS_REQUIRING_OUTPUT_STRUCTURE = { + "structured-answering", +} + + +class BlockManifest(WorkflowBlockManifest): + model_config = ConfigDict( + json_schema_extra={ + "name": "OpenAI", + "version": "v4", + "short_description": "Run OpenAI's GPT models with vision capabilities.", + "long_description": LONG_DESCRIPTION, + "license": "Apache-2.0", + "block_type": "model", + "search_keywords": ["LMM", "VLM", "ChatGPT", "GPT", "OpenAI"], + "is_vlm_block": True, + "task_type_property": "task_type", + "ui_manifest": { + "section": "model", + "icon": "fal fa-atom", + "blockPriority": 5, + "popular": True, + }, + }, + protected_namespaces=(), + ) + type: Literal["roboflow_core/open_ai@v4"] + images: Selector(kind=[IMAGE_KIND]) = ImageInputField + task_type: TaskType = Field( + default="unconstrained", + description="Task type to be performed by model. Value determines required parameters and output response.", + json_schema_extra={ + "values_metadata": RELEVANT_TASKS_METADATA, + "recommended_parsers": { + "structured-answering": "roboflow_core/json_parser@v1", + "classification": "roboflow_core/vlm_as_classifier@v2", + "multi-label-classification": "roboflow_core/vlm_as_classifier@v2", + "object-detection": "roboflow_core/vlm_as_detector@v2", + }, + "always_visible": True, + }, + ) + prompt: Optional[Union[Selector(kind=[STRING_KIND]), str]] = Field( + default=None, + description="Text prompt to the OpenAI model", + examples=["my prompt", "$inputs.prompt"], + json_schema_extra={ + "relevant_for": { + "task_type": {"values": TASKS_REQUIRING_PROMPT, "required": True}, + }, + "multiline": True, + }, + ) + output_structure: Optional[Dict[str, str]] = Field( + default=None, + description="Dictionary with structure of expected JSON response", + examples=[{"my_key": "description"}, "$inputs.output_structure"], + json_schema_extra={ + "relevant_for": { + "task_type": { + "values": TASKS_REQUIRING_OUTPUT_STRUCTURE, + "required": True, + }, + }, + }, + ) + classes: Optional[Union[Selector(kind=[LIST_OF_VALUES_KIND]), List[str]]] = Field( + default=None, + description="List of classes to be used", + examples=[["class-a", "class-b"], "$inputs.classes"], + json_schema_extra={ + "relevant_for": { + "task_type": { + "values": TASKS_REQUIRING_CLASSES, + "required": True, + }, + }, + }, + ) + api_key: Union[ + Selector(kind=[STRING_KIND, SECRET_KIND, ROBOFLOW_MANAGED_KEY]), str + ] = Field( + default="rf_key:account", + description="Your OpenAI API key", + examples=["xxx-xxx", "$inputs.openai_api_key"], + private=True, + ) + model_version: Union[ + Selector(kind=[STRING_KIND]), + Literal[tuple(MODEL_VERSION_IDS)], + ] = Field( + default="gpt-5.1", + description="Model to be used", + examples=["gpt-5.1", "$inputs.openai_model"], + json_schema_extra={ + "values_metadata": MODEL_VERSION_METADATA, + }, + ) + reasoning_effort: Optional[ + Union[ + Selector(kind=[STRING_KIND]), + Literal["none", "minimal", "low", "medium", "high"], + ] + ] = Field( + default=None, + description="Control the effort on reasoning. " + "Reducing reasoning effort can result in faster responses and fewer tokens used. " + "GPT-5.1 defaults to 'none' (no reasoning) and supports 'none', 'low', 'medium', 'high'. 
" + "GPT-5 models default to 'medium' and support 'minimal', 'low', 'medium', 'high'.", + json_schema_extra={ + "relevant_for": { + "model_version": { + "values": MODELS_SUPPORTING_REASONING_EFFORT, + "required": False, + }, + }, + }, + ) + image_detail: Union[ + Selector(kind=[STRING_KIND]), Literal["auto", "high", "low"] + ] = Field( + default="auto", + description="Indicates the image's quality, with 'high' suggesting it is of high resolution and should be processed or displayed with high fidelity.", + examples=["auto", "high", "low"], + ) + max_tokens: Optional[int] = Field( + default=None, + description="Maximum number of tokens the model can generate in its response. " + "If not specified, the model will use its default limit. Minimum value is 16.", + ge=16, + ) + temperature: Optional[Union[float, Selector(kind=[FLOAT_KIND])]] = Field( + default=None, + description="Temperature to sample from the model - value in range 0.0-2.0, the higher - the more " + 'random / "creative" the generations are.', + ge=0.0, + le=2.0, + ) + max_concurrent_requests: Optional[int] = Field( + default=None, + description="Number of concurrent requests that can be executed by block when batch of input images provided. " + "If not given - block defaults to value configured globally in Workflows Execution Engine. " + "Please restrict if you hit OpenAI limits.", + ) + + @model_validator(mode="after") + def validate(self) -> "BlockManifest": + if self.task_type in TASKS_REQUIRING_PROMPT and self.prompt is None: + raise ValueError( + f"`prompt` parameter required to be set for task `{self.task_type}`" + ) + if self.task_type in TASKS_REQUIRING_CLASSES and self.classes is None: + raise ValueError( + f"`classes` parameter required to be set for task `{self.task_type}`" + ) + if ( + self.task_type in TASKS_REQUIRING_OUTPUT_STRUCTURE + and self.output_structure is None + ): + raise ValueError( + f"`output_structure` parameter required to be set for task `{self.task_type}`" + ) + return self + + @classmethod + def get_parameters_accepting_batches(cls) -> List[str]: + return ["images"] + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [ + OutputDefinition( + name="output", kind=[STRING_KIND, LANGUAGE_MODEL_OUTPUT_KIND] + ), + OutputDefinition(name="classes", kind=[LIST_OF_VALUES_KIND]), + ] + + @classmethod + def get_execution_engine_compatibility(cls) -> Optional[str]: + return ">=1.4.0,<2.0.0" + + +class OpenAIBlockV4(WorkflowBlock): + + def __init__( + self, + model_manager: ModelManager, + api_key: Optional[str], + ): + self._model_manager = model_manager + self._api_key = api_key + + @classmethod + def get_init_parameters(cls) -> List[str]: + return ["model_manager", "api_key"] + + @classmethod + def get_manifest(cls) -> Type[WorkflowBlockManifest]: + return BlockManifest + + @classmethod + def get_execution_engine_compatibility(cls) -> Optional[str]: + return ">=1.3.0,<2.0.0" + + def run( + self, + images: Batch[WorkflowImageData], + task_type: TaskType, + prompt: Optional[str], + output_structure: Optional[Dict[str, str]], + classes: Optional[List[str]], + model_version: str, + reasoning_effort: Optional[str], + image_detail: Literal["low", "high", "auto"], + max_tokens: Optional[int], + temperature: Optional[float], + max_concurrent_requests: Optional[int], + api_key: str = "rf_key:account", + ) -> BlockResult: + inference_images = [i.to_inference_format() for i in images] + raw_outputs = run_openai_prompting( + roboflow_api_key=self._api_key, + images=inference_images, + 
+            task_type=task_type,
+            prompt=prompt,
+            output_structure=output_structure,
+            classes=classes,
+            openai_api_key=api_key,
+            model_version=model_version,
+            reasoning_effort=reasoning_effort,
+            image_detail=image_detail,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            max_concurrent_requests=max_concurrent_requests,
+        )
+        return [
+            {"output": raw_output, "classes": classes} for raw_output in raw_outputs
+        ]
+
+
+def run_openai_prompting(
+    roboflow_api_key: Optional[str],
+    images: List[Dict[str, Any]],
+    task_type: TaskType,
+    prompt: Optional[str],
+    output_structure: Optional[Dict[str, str]],
+    classes: Optional[List[str]],
+    openai_api_key: str,
+    model_version: str,
+    reasoning_effort: Optional[str],
+    image_detail: Literal["auto", "high", "low"],
+    max_tokens: Optional[int],
+    temperature: Optional[float],
+    max_concurrent_requests: Optional[int],
+) -> List[str]:
+    if task_type not in PROMPT_BUILDERS:
+        raise ValueError(f"Task type: {task_type} not supported.")
+    openai_prompts = []
+    for image in images:
+        loaded_image, _ = load_image(image)
+        base64_image = base64.b64encode(
+            encode_image_to_jpeg_bytes(loaded_image)
+        ).decode("ascii")
+        generated_prompt = PROMPT_BUILDERS[task_type](
+            base64_image=base64_image,
+            prompt=prompt,
+            output_structure=output_structure,
+            classes=classes,
+            image_detail=image_detail,
+        )
+        openai_prompts.append(generated_prompt)
+    return execute_openai_requests(
+        roboflow_api_key=roboflow_api_key,
+        openai_api_key=openai_api_key,
+        openai_prompts=openai_prompts,
+        model_version=model_version,
+        reasoning_effort=reasoning_effort,
+        max_tokens=max_tokens,
+        temperature=temperature,
+        max_concurrent_requests=max_concurrent_requests,
+    )
+
+
+def execute_openai_requests(
+    roboflow_api_key: Optional[str],
+    openai_api_key: str,
+    openai_prompts: List[dict],
+    model_version: str,
+    reasoning_effort: Optional[str],
+    max_tokens: Optional[int],
+    temperature: Optional[float],
+    max_concurrent_requests: Optional[int],
+) -> List[str]:
+    tasks = [
+        partial(
+            execute_openai_request,
+            roboflow_api_key=roboflow_api_key,
+            openai_api_key=openai_api_key,
+            instructions=prompt.get("instructions"),
+            input_content=prompt["input"],
+            model_version=model_version,
+            reasoning_effort=reasoning_effort,
+            max_tokens=max_tokens,
+            temperature=temperature,
+        )
+        for prompt in openai_prompts
+    ]
+    max_workers = (
+        max_concurrent_requests
+        or WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS
+    )
+    return run_in_parallel(
+        tasks=tasks,
+        max_workers=max_workers,
+    )
+
+
+def _execute_proxied_openai_request(
+    roboflow_api_key: str,
+    openai_api_key: str,
+    instructions: Optional[str],
+    input_content: List[dict],
+    model_version: str,
+    reasoning_effort: Optional[str],
+    max_tokens: Optional[int],
+    temperature: Optional[float],
+) -> str:
+    """Executes OpenAI request via Roboflow proxy."""
+    payload = {
+        "model": model_version,
+        "input": input_content,
+        "openai_api_key": openai_api_key,
+    }
+
+    if instructions is not None:
+        payload["instructions"] = instructions
+
+    if max_tokens is not None:
+        payload["max_output_tokens"] = max_tokens
+
+    if temperature is not None:
+        payload["temperature"] = temperature
+
+    if (
+        reasoning_effort is not None
+        and model_version in MODELS_SUPPORTING_REASONING_EFFORT
+    ):
+        effort_values = MODEL_REASONING_EFFORT_VALUES.get(model_version, [])
+        if reasoning_effort not in effort_values:
+            raise ValueError(
+                f'Model {model_version} does not support reasoning effort "{reasoning_effort}"'
+            )
+        payload["reasoning"] = {"effort": reasoning_effort}
+
+    endpoint = "apiproxy/openai/v2"
+
+    try:
+        # Use the Roboflow API post function (this ensures proper auth headers used based on invocation context)
+        response_data = post_to_roboflow_api(
+            endpoint=endpoint,
+            api_key=roboflow_api_key,
+            payload=payload,
+        )
+        return _extract_output_text(response_data)
+    except requests.exceptions.RequestException as e:
+        raise RuntimeError(f"Failed to connect to Roboflow proxy: {e}") from e
+    except (KeyError, IndexError) as e:
+        raise RuntimeError(
+            f"Invalid response structure from Roboflow proxy: {e}"
+        ) from e
+
+
+def _extract_output_text(response_data: dict) -> str:
+    """Extract output text from OpenAI Responses API response."""
+    status = response_data.get("status")
+
+    if status == "failed":
+        error = response_data.get("error", {})
+        error_message = (
+            f"{error.get('code', 'Unknown')}: {error.get('message', 'Unknown error')}"
+        )
+        raise ValueError(f"OpenAI API request failed: {error_message}")
+
+    if status == "cancelled":
+        raise ValueError("OpenAI API request was cancelled.")
+
+    if status == "incomplete":
+        incomplete_details = response_data.get("incomplete_details", {})
+        reason = incomplete_details.get("reason", "Unknown reason")
+        if reason == "max_output_tokens":
+            raise ValueError(
+                "OpenAI API stopped generation because the max_tokens limit was reached. "
+                "Please increase the max_tokens parameter to allow for a complete response."
+            )
+        raise ValueError(
+            f"OpenAI API returned an incomplete response. Reason: {reason}"
+        )
+
+    if status not in ["completed", "in_progress", "queued", None]:
+        raise ValueError(f"OpenAI API returned unexpected status: {status}")
+
+    # Extract text from output items
+    output_items = response_data.get("output", [])
+    texts = []
+    for item in output_items:
+        if item.get("type") == "message":
+            for content in item.get("content", []):
+                if content.get("type") == "output_text":
+                    texts.append(content.get("text", ""))
+
+    output_text = "".join(texts)
+    if not output_text:
+        raise ValueError("OpenAI API returned no text content in response.")
+
+    return output_text
+
+
+def _execute_direct_openai_request(
+    openai_api_key: str,
+    instructions: Optional[str],
+    input_content: List[dict],
+    model_version: str,
+    reasoning_effort: Optional[str],
+    max_tokens: Optional[int],
+    temperature: Optional[float],
+) -> str:
+    """Executes OpenAI request directly."""
+    client = _get_openai_client(openai_api_key)
+
+    request_params = {
+        "model": model_version,
+        "input": input_content,
+    }
+
+    if instructions is not None:
+        request_params["instructions"] = instructions
+
+    if max_tokens is not None:
+        request_params["max_output_tokens"] = max_tokens
+
+    if temperature is not None:
+        request_params["temperature"] = temperature
+
+    if (
+        reasoning_effort is not None
+        and model_version in MODELS_SUPPORTING_REASONING_EFFORT
+    ):
+        effort_values = MODEL_REASONING_EFFORT_VALUES.get(model_version, [])
+        if reasoning_effort not in effort_values:
+            raise ValueError(
+                f'Model {model_version} does not support reasoning effort "{reasoning_effort}"'
+            )
+        request_params["reasoning"] = {"effort": reasoning_effort}
+
+    response = client.responses.create(**request_params)
+
+    status = response.status
+    if status == "failed":
+        error_message = "Unknown error"
+        if response.error:
+            error_message = f"{response.error.code}: {response.error.message}"
+        raise ValueError(f"OpenAI API request failed: {error_message}")
+
+    if status == "cancelled":
+        raise ValueError("OpenAI API request was cancelled.")
+
status == "incomplete": + reason = "Unknown reason" + if response.incomplete_details: + reason = response.incomplete_details.reason + if reason == "max_output_tokens": + raise ValueError( + "OpenAI API stopped generation because the max_tokens limit was reached. " + "Please increase the max_tokens parameter to allow for a complete response." + ) + raise ValueError( + f"OpenAI API returned an incomplete response. Reason: {reason}" + ) + + if status not in ["completed", "in_progress", "queued"]: + raise ValueError(f"OpenAI API returned unexpected status: {status}") + + output_text = response.output_text + if not output_text: + raise ValueError("OpenAI API returned no text content in response.") + + return output_text + + +def execute_openai_request( + roboflow_api_key: Optional[str], + openai_api_key: str, + instructions: Optional[str], + input_content: List[dict], + model_version: str, + reasoning_effort: Optional[str], + max_tokens: Optional[int], + temperature: Optional[float], +) -> str: + if openai_api_key.startswith(("rf_key:account", "rf_key:user:")): + return _execute_proxied_openai_request( + roboflow_api_key=roboflow_api_key, + openai_api_key=openai_api_key, + instructions=instructions, + input_content=input_content, + model_version=model_version, + reasoning_effort=reasoning_effort, + max_tokens=max_tokens, + temperature=temperature, + ) + else: + return _execute_direct_openai_request( + openai_api_key=openai_api_key, + instructions=instructions, + input_content=input_content, + model_version=model_version, + reasoning_effort=reasoning_effort, + max_tokens=max_tokens, + temperature=temperature, + ) + + +def prepare_unconstrained_prompt( + base64_image: str, + prompt: str, + image_detail: str, + **kwargs, +) -> dict: + return { + "input": [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": prompt}, + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_image}", + "detail": image_detail, + }, + ], + } + ], + } + + +def prepare_classification_prompt( + base64_image: str, + classes: List[str], + image_detail: str, + **kwargs, +) -> dict: + serialised_classes = ", ".join(classes) + return { + "instructions": ( + "You act as single-class classification model. You must provide reasonable predictions. " + "You are only allowed to produce JSON document in Markdown ```json [...]``` markers. " + 'Expected structure of json: {"class_name": "class-name", "confidence": 0.4}. ' + "`class-name` must be one of the class names defined by user. You are only allowed to return " + "single JSON document, even if there are potentially multiple classes. You are not allowed to return list." + ), + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": f"List of all classes to be recognised by model: {serialised_classes}", + }, + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_image}", + "detail": image_detail, + }, + ], + } + ], + } + + +def prepare_multi_label_classification_prompt( + base64_image: str, + classes: List[str], + image_detail: str, + **kwargs, +) -> dict: + serialised_classes = ", ".join(classes) + return { + "instructions": ( + "You act as multi-label classification model. You must provide reasonable predictions. " + "You are only allowed to produce JSON document in Markdown ```json``` markers. " + 'Expected structure of json: {"predicted_classes": [{"class": "class-name-1", "confidence": 0.9}, ' + '{"class": "class-name-2", "confidence": 0.7}]}. 
' + "`class-name-X` must be one of the class names defined by user and `confidence` is a float value in range " + "0.0-1.0 that represent how sure you are that the class is present in the image. Only return class names " + "that are visible." + ), + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": f"List of all classes to be recognised by model: {serialised_classes}", + }, + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_image}", + "detail": image_detail, + }, + ], + } + ], + } + + +def prepare_vqa_prompt( + base64_image: str, + prompt: str, + image_detail: str, + **kwargs, +) -> dict: + return { + "instructions": ( + "You act as Visual Question Answering model. Your task is to provide answer to question " + "submitted by user. If this is open-question - answer with few sentences, for ABCD question, " + "return only the indicator of the answer." + ), + "input": [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": f"Question: {prompt}"}, + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_image}", + "detail": image_detail, + }, + ], + } + ], + } + + +def prepare_ocr_prompt( + base64_image: str, + image_detail: str, + **kwargs, +) -> dict: + return { + "instructions": ( + "You act as OCR model. Your task is to read text from the image and return it in " + "paragraphs representing the structure of texts in the image. You should only return " + "recognised text, nothing else." + ), + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_image}", + "detail": image_detail, + }, + ], + } + ], + } + + +def prepare_caption_prompt( + base64_image: str, + image_detail: str, + short_description: bool, + **kwargs, +) -> dict: + caption_detail_level = "Caption should be short." + if not short_description: + caption_detail_level = "Caption should be extensive." + return { + "instructions": ( + f"You act as image caption model. Your task is to provide description of the image. " + f"{caption_detail_level}" + ), + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_image}", + "detail": image_detail, + }, + ], + } + ], + } + + +def prepare_structured_answering_prompt( + base64_image: str, + output_structure: Dict[str, str], + image_detail: str, + **kwargs, +) -> dict: + output_structure_serialised = json.dumps(output_structure, indent=4) + return { + "instructions": ( + "You are supposed to produce responses in JSON wrapped in Markdown markers: " + "```json\nyour-response\n```. User is to provide you dictionary with keys and values. " + "Each key must be present in your response. Values in user dictionary represent " + "descriptions for JSON fields to be generated. Provide only JSON Markdown in response." + ), + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": f"Specification of requirements regarding output fields: \n" + f"{output_structure_serialised}", + }, + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_image}", + "detail": image_detail, + }, + ], + } + ], + } + + +def prepare_object_detection_prompt( + base64_image: str, + classes: List[str], + image_detail: str, + **kwargs, +) -> dict: + serialised_classes = ", ".join(classes) + return { + "instructions": ( + "You act as object-detection model. You must provide reasonable predictions. " + "You are only allowed to produce JSON document. 
" + 'Expected structure of json: {"detections": [{"x_min": 0.1, "y_min": 0.2, "x_max": 0.3, "y_max": 0.4, "class_name": "my-class-X", "confidence": 0.7}]}. ' + "`my-class-X` must be one of the class names defined by user. All coordinates must be in range 0.0-1.0, representing percentage of image dimensions. " + "`confidence` is a value in range 0.0-1.0 representing your confidence in prediction. You should detect all instances of classes provided by user." + ), + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": f"List of all classes to be recognised by model: {serialised_classes}", + }, + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_image}", + "detail": image_detail, + }, + ], + } + ], + } + + +def _get_openai_client(api_key: str): + client = _openai_client_cache.get(api_key) + if client is None: + client = OpenAI(api_key=api_key) + _openai_client_cache[api_key] = client + return client + + +PROMPT_BUILDERS = { + "unconstrained": prepare_unconstrained_prompt, + "ocr": prepare_ocr_prompt, + "visual-question-answering": prepare_vqa_prompt, + "caption": partial(prepare_caption_prompt, short_description=True), + "detailed-caption": partial(prepare_caption_prompt, short_description=False), + "classification": prepare_classification_prompt, + "multi-label-classification": prepare_multi_label_classification_prompt, + "structured-answering": prepare_structured_answering_prompt, + "object-detection": prepare_object_detection_prompt, +} + +_openai_client_cache = {} diff --git a/tests/workflows/unit_tests/core_steps/models/foundation/test_openai.py b/tests/workflows/unit_tests/core_steps/models/foundation/test_openai_v1.py similarity index 100% rename from tests/workflows/unit_tests/core_steps/models/foundation/test_openai.py rename to tests/workflows/unit_tests/core_steps/models/foundation/test_openai_v1.py diff --git a/tests/workflows/unit_tests/core_steps/models/foundation/test_openai_v4.py b/tests/workflows/unit_tests/core_steps/models/foundation/test_openai_v4.py new file mode 100644 index 0000000000..40d0ee12d1 --- /dev/null +++ b/tests/workflows/unit_tests/core_steps/models/foundation/test_openai_v4.py @@ -0,0 +1,739 @@ +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +import pytest +from pydantic import ValidationError + +from inference.core.workflows.core_steps.models.foundation.openai.v4 import ( + MODEL_REASONING_EFFORT_VALUES, + MODELS_NOT_SUPPORTING_REASONING_EFFORT, + MODELS_SUPPORTING_REASONING_EFFORT, + BlockManifest, + _execute_direct_openai_request, + _execute_proxied_openai_request, + _extract_output_text, + execute_openai_request, + prepare_classification_prompt, + prepare_multi_label_classification_prompt, + prepare_object_detection_prompt, + prepare_ocr_prompt, + prepare_structured_answering_prompt, + prepare_unconstrained_prompt, + prepare_vqa_prompt, +) + + +def test_openai_step_validation_when_input_is_valid() -> None: + # given + specification = { + "type": "roboflow_core/open_ai@v4", + "name": "step_1", + "images": "$inputs.image", + "task_type": "unconstrained", + "prompt": "$inputs.prompt", + "api_key": "$inputs.openai_api_key", + } + + # when + result = BlockManifest.model_validate(specification) + + # then + assert result.type == "roboflow_core/open_ai@v4" + assert result.name == "step_1" + assert result.images == "$inputs.image" + assert result.task_type == "unconstrained" + assert result.prompt == "$inputs.prompt" + assert result.api_key == "$inputs.openai_api_key" + + +def 
+def test_openai_step_validation_with_default_api_key() -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+    }
+
+    # when
+    result = BlockManifest.model_validate(specification)
+
+    # then
+    assert result.api_key == "rf_key:account"
+
+
+@pytest.mark.parametrize("value", [None, 1, "a", True])
+def test_openai_step_validation_when_image_is_invalid(value: Any) -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": value,
+        "task_type": "unconstrained",
+        "prompt": "$inputs.prompt",
+        "api_key": "$inputs.openai_api_key",
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(specification)
+
+
+def test_openai_step_validation_when_prompt_is_given_directly() -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "unconstrained",
+        "prompt": "This is my prompt",
+        "api_key": "$inputs.openai_api_key",
+    }
+
+    # when
+    result = BlockManifest.model_validate(specification)
+
+    # then
+    assert result.prompt == "This is my prompt"
+
+
+@pytest.mark.parametrize(
+    "model_version",
+    [
+        "gpt-5.1",
+        "gpt-5",
+        "gpt-5-mini",
+        "gpt-5-nano",
+        "gpt-4.1",
+        "gpt-4o",
+        "$inputs.model",
+    ],
+)
+def test_openai_step_validation_when_model_version_valid(model_version: str) -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+        "api_key": "$inputs.openai_api_key",
+        "model_version": model_version,
+    }
+
+    # when
+    result = BlockManifest.model_validate(specification)
+
+    # then
+    assert result.model_version == model_version
+
+
+@pytest.mark.parametrize("value", ["invalid-model", 123])
+def test_openai_step_validation_when_model_version_invalid(value: Any) -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+        "api_key": "$inputs.openai_api_key",
+        "model_version": value,
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(specification)
+
+
+@pytest.mark.parametrize(
+    "reasoning_effort", ["none", "minimal", "low", "medium", "high", "$inputs.effort"]
+)
+def test_openai_step_validation_with_reasoning_effort(reasoning_effort: str) -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+        "api_key": "$inputs.openai_api_key",
+        "model_version": "gpt-5.1",
+        "reasoning_effort": reasoning_effort,
+    }
+
+    # when
+    result = BlockManifest.model_validate(specification)
+
+    # then
+    assert result.reasoning_effort == reasoning_effort
+
+
+@pytest.mark.parametrize("value", ["invalid", 123, "very_high"])
+def test_openai_step_validation_when_reasoning_effort_invalid(value: Any) -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+        "api_key": "$inputs.openai_api_key",
+        "reasoning_effort": value,
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(specification)
+
+
+def test_openai_step_validation_with_temperature() -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+        "api_key": "$inputs.openai_api_key",
+        "temperature": 0.7,
+    }
+
+    # when
+    result = BlockManifest.model_validate(specification)
+
+    # then
+    assert result.temperature == 0.7
+
+
+@pytest.mark.parametrize("value", [-0.1, 2.1, "invalid"])
+def test_openai_step_validation_when_temperature_invalid(value: Any) -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+        "api_key": "$inputs.openai_api_key",
+        "temperature": value,
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(specification)
+
+
+def test_openai_step_validation_with_max_tokens() -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+        "api_key": "$inputs.openai_api_key",
+        "max_tokens": 100,
+    }
+
+    # when
+    result = BlockManifest.model_validate(specification)
+
+    # then
+    assert result.max_tokens == 100
+
+
+def test_openai_step_validation_with_max_tokens_minimum_value() -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+        "api_key": "$inputs.openai_api_key",
+        "max_tokens": 16,
+    }
+
+    # when
+    result = BlockManifest.model_validate(specification)
+
+    # then
+    assert result.max_tokens == 16
+
+
+@pytest.mark.parametrize("value", [15, 10, 0, -1])
+def test_openai_step_validation_when_max_tokens_below_minimum(value: int) -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "caption",
+        "api_key": "$inputs.openai_api_key",
+        "max_tokens": value,
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(specification)
+
+
+def test_openai_step_validation_without_required_prompt() -> None:
+    # given - unconstrained requires prompt
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "unconstrained",
+        "api_key": "$inputs.openai_api_key",
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(specification)
+
+
+def test_openai_step_validation_without_required_classes() -> None:
+    # given - classification requires classes
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "classification",
+        "api_key": "$inputs.openai_api_key",
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(specification)
+
+
+def test_openai_step_validation_without_required_output_structure() -> None:
+    # given - structured-answering requires output_structure
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "structured-answering",
+        "api_key": "$inputs.openai_api_key",
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(specification)
+
+
+def test_openai_step_validation_with_classification_and_classes() -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "classification",
+        "classes": ["cat", "dog"],
+        "api_key": "$inputs.openai_api_key",
+    }
+
+    # when
+    result = BlockManifest.model_validate(specification)
+
+    # then
+    assert result.task_type == "classification"
+    assert result.classes == ["cat", "dog"]
+
+
+def test_openai_step_validation_with_object_detection_and_classes() -> None:
+    # given
+    specification = {
+        "type": "roboflow_core/open_ai@v4",
+        "name": "step_1",
+        "images": "$inputs.image",
+        "task_type": "object-detection",
+        "classes": ["person", "car"],
+        "api_key": "$inputs.openai_api_key",
+    }
+
+    # when
+    result = BlockManifest.model_validate(specification)
+
+    # then
+    assert result.task_type == "object-detection"
+    assert result.classes == ["person", "car"]
+
+
+def test_extract_output_text_success() -> None:
+    # given
+    response_data = {
+        "status": "completed",
+        "output": [
+            {
+                "type": "message",
+                "content": [{"type": "output_text", "text": "This is the response"}],
+            }
+        ],
+    }
+
+    # when
+    result = _extract_output_text(response_data)
+
+    # then
+    assert result == "This is the response"
+
+
+def test_extract_output_text_with_multiple_text_blocks() -> None:
+    # given
+    response_data = {
+        "status": "completed",
+        "output": [
+            {
+                "type": "message",
+                "content": [
+                    {"type": "output_text", "text": "Part 1"},
+                    {"type": "output_text", "text": " Part 2"},
+                ],
+            }
+        ],
+    }
+
+    # when
+    result = _extract_output_text(response_data)
+
+    # then
+    assert result == "Part 1 Part 2"
+
+
+def test_extract_output_text_failed_status() -> None:
+    # given
+    response_data = {
+        "status": "failed",
+        "error": {"code": "invalid_request", "message": "Bad request"},
+    }
+
+    # when/then
+    with pytest.raises(ValueError) as exc_info:
+        _extract_output_text(response_data)
+
+    assert "OpenAI API request failed" in str(exc_info.value)
+    assert "invalid_request" in str(exc_info.value)
+
+
+def test_extract_output_text_cancelled_status() -> None:
+    # given
+    response_data = {"status": "cancelled"}
+
+    # when/then
+    with pytest.raises(ValueError) as exc_info:
+        _extract_output_text(response_data)
+
+    assert "cancelled" in str(exc_info.value)
+
+
+def test_extract_output_text_incomplete_max_tokens() -> None:
+    # given
+    response_data = {
+        "status": "incomplete",
+        "incomplete_details": {"reason": "max_output_tokens"},
+    }
+
+    # when/then
+    with pytest.raises(ValueError) as exc_info:
+        _extract_output_text(response_data)
+
+    assert "max_tokens limit was reached" in str(exc_info.value)
+    assert "increase the max_tokens parameter" in str(exc_info.value)
+
+
+def test_extract_output_text_incomplete_other_reason() -> None:
+    # given
+    response_data = {
+        "status": "incomplete",
+        "incomplete_details": {"reason": "content_filter"},
+    }
+
+    # when/then
+    with pytest.raises(ValueError) as exc_info:
+        _extract_output_text(response_data)
+
+    assert "incomplete response" in str(exc_info.value)
+    assert "content_filter" in str(exc_info.value)
+
+
+def test_extract_output_text_unexpected_status() -> None:
+    # given
+    response_data = {"status": "unknown_status"}
+
+    # when/then
+    with pytest.raises(ValueError) as exc_info:
+        _extract_output_text(response_data)
+
+    assert "unexpected status" in str(exc_info.value)
+
+
+def test_extract_output_text_no_text_content() -> None:
+    # given
+    response_data = {
+        "status": "completed",
+        "output": [],
+    }
+
+    # when/then
+    with pytest.raises(ValueError) as exc_info:
+        _extract_output_text(response_data)
+
+    assert "no text content" in str(exc_info.value)
+
+
+def test_execute_openai_request_routes_to_proxy_for_rf_key_account() -> None:
+    # given
+    with patch(
+        "inference.core.workflows.core_steps.models.foundation.openai.v4._execute_proxied_openai_request"
+    ) as mock_proxy:
+        mock_proxy.return_value = "proxied response"
+
+        # when
+        result = execute_openai_request(
+            roboflow_api_key="rf_api_key",
+            openai_api_key="rf_key:account",
+            instructions="test",
+            input_content=[],
+            model_version="gpt-5.1",
+            reasoning_effort=None,
+            max_tokens=None,
+            temperature=None,
+        )
+
+    # then
+    assert result == "proxied response"
+    mock_proxy.assert_called_once()
+
+
+def test_execute_openai_request_routes_to_proxy_for_rf_key_user() -> None:
+    # given
+    with patch(
+        "inference.core.workflows.core_steps.models.foundation.openai.v4._execute_proxied_openai_request"
+    ) as mock_proxy:
+        mock_proxy.return_value = "proxied response"
+
+        # when
+        result = execute_openai_request(
+            roboflow_api_key="rf_api_key",
+            openai_api_key="rf_key:user:12345",
+            instructions="test",
+            input_content=[],
+            model_version="gpt-5.1",
+            reasoning_effort=None,
+            max_tokens=None,
+            temperature=None,
+        )
+
+    # then
+    assert result == "proxied response"
+    mock_proxy.assert_called_once()
+
+
+def test_execute_openai_request_routes_to_direct_for_regular_api_key() -> None:
+    # given
+    with patch(
+        "inference.core.workflows.core_steps.models.foundation.openai.v4._execute_direct_openai_request"
+    ) as mock_direct:
+        mock_direct.return_value = "direct response"
+
+        # when
+        result = execute_openai_request(
+            roboflow_api_key="rf_api_key",
+            openai_api_key="sk-test-key",
+            instructions="test",
+            input_content=[],
+            model_version="gpt-5.1",
+            reasoning_effort=None,
+            max_tokens=None,
+            temperature=None,
+        )
+
+    # then
+    assert result == "direct response"
+    mock_direct.assert_called_once()
+
+
+@patch(
+    "inference.core.workflows.core_steps.models.foundation.openai.v4._get_openai_client"
+)
+def test_direct_request_with_valid_reasoning_effort_for_gpt_5_1(
+    mock_get_client: Mock,
+) -> None:
+    # given
+    mock_client = MagicMock()
+    mock_response = MagicMock()
+    mock_response.status = "completed"
+    mock_response.output_text = "response"
+    mock_client.responses.create.return_value = mock_response
+    mock_get_client.return_value = mock_client
+
+    # when
+    result = _execute_direct_openai_request(
+        openai_api_key="sk-test",
+        instructions="test",
+        input_content=[{"role": "user", "content": []}],
+        model_version="gpt-5.1",
+        reasoning_effort="high",
+        max_tokens=None,
+        temperature=None,
+    )
+
+    # then
+    assert result == "response"
+    call_kwargs = mock_client.responses.create.call_args[1]
+    assert call_kwargs["reasoning"] == {"effort": "high"}
+
+
+@patch(
+    "inference.core.workflows.core_steps.models.foundation.openai.v4._get_openai_client"
+)
+def test_direct_request_with_invalid_reasoning_effort_for_gpt_5_1_raises_error(
+    mock_get_client: Mock,
+) -> None:
+    # given
+    mock_client = MagicMock()
+    mock_get_client.return_value = mock_client
+
+    # when/then
+    with pytest.raises(ValueError) as exc_info:
+        _execute_direct_openai_request(
+            openai_api_key="sk-test",
+            instructions="test",
+            input_content=[{"role": "user", "content": []}],
+            model_version="gpt-5.1",
+            reasoning_effort="minimal",  # not supported by gpt-5.1
+            max_tokens=None,
+            temperature=None,
+        )
+
+    assert 'does not support reasoning effort "minimal"' in str(exc_info.value)
+
+
+@patch(
+    "inference.core.workflows.core_steps.models.foundation.openai.v4.post_to_roboflow_api"
+)
+def test_proxied_request_with_invalid_reasoning_effort_for_gpt_5_raises_error(
+    mock_post: Mock,
+) -> None:
+    # when/then
+    with pytest.raises(ValueError) as exc_info:
+        _execute_proxied_openai_request(
+            roboflow_api_key="rf_api_key",
+            openai_api_key="rf_key:account",
+            instructions="test",
+            input_content=[{"role": "user", "content": []}],
+            model_version="gpt-5",
+            reasoning_effort="none",  # not supported by gpt-5
+            max_tokens=None,
+            temperature=None,
+        )
+
+    assert 'does not support reasoning effort "none"' in str(exc_info.value)
+
+
+def test_prepare_unconstrained_prompt() -> None:
+    # when
+    result = prepare_unconstrained_prompt(
+        base64_image="test_image_data",
+        prompt="Describe this image",
+        image_detail="high",
+    )
+
+    # then
+    assert "input" in result
+    assert len(result["input"]) == 1
+    user_message = result["input"][0]
+    assert user_message["role"] == "user"
+    assert len(user_message["content"]) == 2
+    assert user_message["content"][0]["type"] == "input_text"
+    assert user_message["content"][0]["text"] == "Describe this image"
+    assert user_message["content"][1]["type"] == "input_image"
+    assert user_message["content"][1]["detail"] == "high"
+
+
+def test_prepare_classification_prompt() -> None:
+    # when
+    result = prepare_classification_prompt(
+        base64_image="test_image_data",
+        classes=["cat", "dog", "bird"],
+        image_detail="auto",
+    )
+
+    # then
+    assert "instructions" in result
+    assert "classification model" in result["instructions"]
+    assert "JSON document" in result["instructions"]
+    user_content = result["input"][0]["content"]
+    assert "cat, dog, bird" in user_content[0]["text"]
+
+
+def test_prepare_multi_label_classification_prompt() -> None:
+    # when
+    result = prepare_multi_label_classification_prompt(
+        base64_image="test_image_data",
+        classes=["sunny", "cloudy"],
+        image_detail="low",
+    )
+
+    # then
+    assert "instructions" in result
+    assert "multi-label classification" in result["instructions"]
+    assert "predicted_classes" in result["instructions"]
+
+
+def test_prepare_vqa_prompt() -> None:
+    # when
+    result = prepare_vqa_prompt(
+        base64_image="test_image_data",
+        prompt="What color is the car?",
+        image_detail="auto",
+    )
+
+    # then
+    assert "instructions" in result
+    assert "Visual Question Answering" in result["instructions"]
+    user_content = result["input"][0]["content"]
+    assert "Question: What color is the car?" in user_content[0]["text"]
+
+
+def test_prepare_ocr_prompt() -> None:
+    # when
+    result = prepare_ocr_prompt(
+        base64_image="test_image_data",
+        image_detail="high",
+    )
+
+    # then
+    assert "instructions" in result
+    assert "OCR model" in result["instructions"]
+    user_content = result["input"][0]["content"]
+    assert len(user_content) == 1
+    assert user_content[0]["type"] == "input_image"
+
+
+def test_prepare_structured_answering_prompt() -> None:
+    # when
+    result = prepare_structured_answering_prompt(
+        base64_image="test_image_data",
+        output_structure={"name": "person name", "age": "estimated age"},
+        image_detail="auto",
+    )
+
+    # then
+    assert "instructions" in result
+    assert "JSON" in result["instructions"]
+    user_content = result["input"][0]["content"]
+    assert "name" in user_content[0]["text"]
+    assert "age" in user_content[0]["text"]
+
+
+def test_prepare_object_detection_prompt() -> None:
+    # when
+    result = prepare_object_detection_prompt(
+        base64_image="test_image_data",
+        classes=["person", "car"],
+        image_detail="high",
+    )
+
+    # then
+    assert "instructions" in result
+    assert "object-detection model" in result["instructions"]
+    assert "detections" in result["instructions"]
+    assert "x_min" in result["instructions"]
+    user_content = result["input"][0]["content"]
+    assert "person, car" in user_content[0]["text"]
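
Usage sketch (editor's note, not part of the patch): the v4 manifest above recommends roboflow_core/vlm_as_detector@v2 as the parser for object-detection responses, and this change adds "openai" to that parser's model_type literal. A minimal workflow definition wiring the two together could look like the Python snippet below. The overall shape follows the standard Workflows definition schema; the step names, the parser's image/vlm_output field names, and the "predictions" output selector are illustrative assumptions not shown in this diff.

    # Hypothetical wiring of open_ai@v4 (object-detection) into vlm_as_detector@v2.
    # Parser field names are assumed from earlier block versions, not from this diff.
    workflow_definition = {
        "version": "1.0",
        "inputs": [
            {"type": "WorkflowImage", "name": "image"},
            {"type": "WorkflowParameter", "name": "classes"},
        ],
        "steps": [
            {
                "type": "roboflow_core/open_ai@v4",
                "name": "gpt",
                "images": "$inputs.image",
                "task_type": "object-detection",
                "classes": "$inputs.classes",
                # api_key defaults to "rf_key:account", routing via the Roboflow proxy
            },
            {
                "type": "roboflow_core/vlm_as_detector@v2",
                "name": "parser",
                "image": "$inputs.image",
                "vlm_output": "$steps.gpt.output",
                "classes": "$steps.gpt.classes",
                "model_type": "openai",
                "task_type": "object-detection",
            },
        ],
        "outputs": [
            {
                "type": "JsonField",
                "name": "predictions",
                "selector": "$steps.parser.predictions",
            }
        ],
    }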