Commit ea56c11

Merge branch 'main' into feat/openai-block-v4

2 parents 23e3094 + b3aeb4d

3 files changed: +233 −93 lines changed

inference/core/workflows/core_steps/models/foundation/google_vision_ocr/v1.py

Lines changed: 165 additions & 91 deletions
@@ -7,6 +7,7 @@
 from pydantic import ConfigDict, Field
 from supervision.config import CLASS_NAME_DATA_FIELD
 
+from inference.core.roboflow_api import post_to_roboflow_api
 from inference.core.workflows.core_steps.common.utils import (
     attach_parents_coordinates_to_sv_detections,
 )
@@ -22,6 +23,7 @@
 from inference.core.workflows.execution_engine.entities.types import (
     IMAGE_KIND,
     OBJECT_DETECTION_PREDICTION_KIND,
+    ROBOFLOW_MANAGED_KEY,
     SECRET_KIND,
     STRING_KIND,
     Selector,
@@ -40,7 +42,8 @@
 - `text_detection`: optimized for areas of text within a larger image.
 - `ocr_text_detection`: optimized for dense text documents.
 
-You need to provide your Google Vision API key to use this block.
+Provide your Google Vision API key or set the value to ``rf_key:account`` (or
+``rf_key:user:<id>``) to proxy requests through Roboflow's API.
 """
 
 
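As a point of reference, a step using this block can now opt into the proxy simply by leaving the key at its new default. A minimal sketch of a step definition; the block type name `roboflow_core/google_vision_ocr@v1` is an assumption, not confirmed by this diff:

step = {
    # Hypothetical block type name; confirm against the block's registration.
    "type": "roboflow_core/google_vision_ocr@v1",
    "name": "ocr",
    "image": "$inputs.image",
    "ocr_type": "text_detection",
    # "rf_key:account" (the new default) or "rf_key:user:<id>" routes the
    # request through Roboflow's API proxy; any other string is used as a
    # direct Google Vision API key, and a selector like
    # "$inputs.google_api_key" is resolved by the workflow engine.
    "api_key": "rf_key:account",
}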
@@ -80,17 +83,20 @@ class BlockManifest(WorkflowBlockManifest):
             },
         },
     )
+    api_key: Union[
+        Selector(kind=[STRING_KIND, SECRET_KIND, ROBOFLOW_MANAGED_KEY]), str
+    ] = Field(
+        default="rf_key:account",
+        description="Your Google Vision API key",
+        examples=["xxx-xxx", "$inputs.google_api_key"],
+        private=True,
+    )
     language_hints: Optional[List[str]] = Field(
         default=None,
         description="Optional list of language codes to pass to the OCR API. If not provided, the API will attempt to detect the language automatically."
        "If provided, language codes must be supported by the OCR API, visit https://cloud.google.com/vision/docs/languages for list of supported language codes.",
         examples=[["en", "fr"], ["de"]],
     )
-    api_key: Union[Selector(kind=[STRING_KIND, SECRET_KIND]), str] = Field(
-        description="Your Google Vision API key",
-        examples=["xxx-xxx", "$inputs.google_api_key"],
-        private=True,
-    )
 
     @classmethod
     def describe_outputs(cls) -> List[OutputDefinition]:
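For readers skimming the manifest change: `language_hints` feeds straight into the Vision API payload built by `_build_request_json` further down this diff. A sketch of the assumed request body for one image with hints supplied:

payload = {
    "requests": [
        {
            "image": {"content": "<base64-encoded image>"},
            "features": [{"type": "TEXT_DETECTION"}],
            # Present only when language_hints is provided, e.g. ["en", "fr"]:
            "imageContext": {"languageHints": ["en", "fr"]},
        }
    ]
}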
@@ -109,6 +115,16 @@ def get_execution_engine_compatibility(cls) -> Optional[str]:
 
 class GoogleVisionOCRBlockV1(WorkflowBlock):
 
+    def __init__(
+        self,
+        api_key: Optional[str],
+    ):
+        self._roboflow_api_key = api_key
+
+    @classmethod
+    def get_init_parameters(cls) -> List[str]:
+        return ["api_key"]
+
     @classmethod
     def get_manifest(cls) -> Type[WorkflowBlockManifest]:
         return BlockManifest
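The init/run split matters here: the Roboflow API key used to authenticate proxied calls arrives once at construction time (via `get_init_parameters`), while the Google key or `rf_key:` value arrives with each `run()` call. A minimal sketch of the assumed engine-side wiring, with a placeholder key value:

# Sketch only: the execution engine is assumed to resolve each name returned
# by get_init_parameters() and pass it to the block's constructor.
init_values = {"api_key": "<roboflow-api-key>"}  # placeholder value
block = GoogleVisionOCRBlockV1(
    **{name: init_values[name] for name in GoogleVisionOCRBlockV1.get_init_parameters()}
)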
@@ -118,105 +134,163 @@ def run(
         image: WorkflowImageData,
         ocr_type: Literal["text_detection", "ocr_text_detection"],
         language_hints: Optional[List[str]],
-        api_key: str,
+        api_key: str = "rf_key:account",
     ) -> BlockResult:
-        # Decide which type of OCR to use and make the request to Google Vision API
+        # Decide which type of OCR to use
         if ocr_type == "text_detection":
-            type = "TEXT_DETECTION"
+            detection_type = "TEXT_DETECTION"
         elif ocr_type == "ocr_text_detection":
-            type = "DOCUMENT_TEXT_DETECTION"
+            detection_type = "DOCUMENT_TEXT_DETECTION"
         else:
             raise ValueError(f"Invalid ocr_type: {ocr_type}")
 
-        request_json = {
-            "requests": [
-                {
-                    "image": {"content": image.base64_image},
-                    "features": [{"type": type}],
-                }
-            ]
-        }
-
-        if language_hints is not None:
-            for r in request_json["requests"]:
-                r["imageContext"] = {"languageHints": language_hints}
-        response = requests.post(
-            "https://vision.googleapis.com/v1/images:annotate",
-            params={"key": api_key},
-            json=request_json,
+        request_json = _build_request_json(
+            image=image,
+            detection_type=detection_type,
+            language_hints=language_hints,
         )
 
-        if response.status_code != 200:
-            raise RuntimeError(
-                f"Request to Google Cloud Vision API failed: {str(response.json())}"
+        # Route to proxy or direct API based on api_key format
+        if api_key.startswith(("rf_key:account", "rf_key:user:")):
+            result = _execute_proxied_google_vision_request(
+                roboflow_api_key=self._roboflow_api_key,
+                google_vision_api_key=api_key,
+                request_json=request_json,
+            )
+        else:
+            result = _execute_google_vision_request(
+                api_key=api_key,
+                request_json=request_json,
             )
 
-        result = response.json()["responses"][0]
-
-        # Check for image without text
-        if "textAnnotations" not in result or not result["textAnnotations"]:
-            return {
-                "text": "",
-                "language": "",
-                "predictions": sv.Detections.empty(),
-            }
-
-        # Extract predictions from the response
-        text = result["textAnnotations"][0]["description"]
-        language = result["textAnnotations"][0]["locale"]
-
-        xyxy = []
-        confidence = []
-        classes = []
-        detections_id = []
-
-        for page in result["fullTextAnnotation"]["pages"]:
-            for block in page["blocks"]:
-                # Get bounding box coordinates
-                box = block["boundingBox"]["vertices"]
-                x_min = min(v.get("x", 0) for v in box)
-                y_min = min(v.get("y", 0) for v in box)
-                x_max = max(v.get("x", 0) for v in box)
-                y_max = max(v.get("y", 0) for v in box)
-                xyxy.append([x_min, y_min, x_max, y_max])
-
-                # Only DOCUMENT_TEXT_DETECTION provides confidence score, use 1.0 otherwise
-                confidence.append(block.get("confidence", 1.0))
-
-                # Get block text
-                block_text = []
-                for paragraph in block["paragraphs"]:
-                    for word in paragraph["words"]:
-                        word_text = "".join(
-                            symbol["text"] for symbol in word["symbols"]
-                        )
-                        block_text.append(word_text)
-                classes.append(" ".join(block_text))
-
-                # Create unique detection id for each block
-                detections_id.append(uuid4())
-
-        predictions = sv.Detections(
-            xyxy=np.array(xyxy),
-            confidence=np.array(confidence),
-            class_id=np.arange(len(classes)),
-            data={CLASS_NAME_DATA_FIELD: np.array(classes)},
-        )
+        return _parse_google_vision_response(result=result, image=image)
+
+
+def _build_request_json(
+    image: WorkflowImageData,
+    detection_type: str,
+    language_hints: Optional[List[str]],
+) -> dict:
+    ocr_request = {
+        "image": {"content": image.base64_image},
+        "features": [{"type": detection_type}],
+    }
+
+    if language_hints is not None:
+        ocr_request["imageContext"] = {"languageHints": language_hints}
+
+    return {"requests": [ocr_request]}
 
-        predictions[DETECTION_ID_KEY] = np.array(detections_id)
-        predictions[PREDICTION_TYPE_KEY] = np.array(["ocr"] * len(predictions))
-        image_height, image_width = image.numpy_image.shape[:2]
-        predictions[IMAGE_DIMENSIONS_KEY] = np.array(
-            [[image_height, image_width]] * len(predictions)
+
+def _execute_proxied_google_vision_request(
+    roboflow_api_key: str,
+    google_vision_api_key: str,
+    request_json: dict,
+) -> dict:
+    payload = {
+        "google_vision_api_key": google_vision_api_key,
+        "request_json": request_json,
+    }
+
+    try:
+        response_data = post_to_roboflow_api(
+            endpoint="apiproxy/google_vision_ocr",
+            api_key=roboflow_api_key,
+            payload=payload,
         )
+        return response_data["responses"][0]
+    except requests.exceptions.RequestException as e:
+        raise RuntimeError(f"Failed to connect to Roboflow proxy: {e}") from e
+    except (KeyError, IndexError) as e:
+        raise RuntimeError(
+            f"Invalid response structure from Roboflow proxy: {e}"
+        ) from e
 
-        predictions = attach_parents_coordinates_to_sv_detections(
-            detections=predictions,
-            image=image,
+
+def _execute_google_vision_request(
+    api_key: str,
+    request_json: dict,
+) -> dict:
+    response = requests.post(
+        "https://vision.googleapis.com/v1/images:annotate",
+        params={"key": api_key},
+        json=request_json,
+    )
+
+    if response.status_code != 200:
+        raise RuntimeError(
+            f"Request to Google Cloud Vision API failed: {str(response.json())}"
         )
 
+    return response.json()["responses"][0]
+
+
+def _parse_google_vision_response(
+    result: dict,
+    image: WorkflowImageData,
+) -> BlockResult:
+    # Check for image without text
+    if "textAnnotations" not in result or not result["textAnnotations"]:
         return {
-            "text": text,
-            "language": language,
-            "predictions": predictions,
+            "text": "",
+            "language": "",
+            "predictions": sv.Detections.empty(),
         }
+
+    # Extract predictions from the response
+    text = result["textAnnotations"][0]["description"]
+    language = result["textAnnotations"][0]["locale"]
+
+    xyxy = []
+    confidence = []
+    classes = []
+    detections_id = []
+
+    for page in result["fullTextAnnotation"]["pages"]:
+        for block in page["blocks"]:
+            # Get bounding box coordinates
+            box = block["boundingBox"]["vertices"]
+            x_min = min(v.get("x", 0) for v in box)
+            y_min = min(v.get("y", 0) for v in box)
+            x_max = max(v.get("x", 0) for v in box)
+            y_max = max(v.get("y", 0) for v in box)
+            xyxy.append([x_min, y_min, x_max, y_max])
+
+            # Only DOCUMENT_TEXT_DETECTION provides confidence score, use 1.0 otherwise
+            confidence.append(block.get("confidence", 1.0))
+
+            # Get block text
+            block_text = []
+            for paragraph in block["paragraphs"]:
+                for word in paragraph["words"]:
+                    word_text = "".join(symbol["text"] for symbol in word["symbols"])
+                    block_text.append(word_text)
+            classes.append(" ".join(block_text))
+
+            # Create unique detection id for each block
+            detections_id.append(uuid4())
+
+    predictions = sv.Detections(
+        xyxy=np.array(xyxy),
+        confidence=np.array(confidence),
+        class_id=np.arange(len(classes)),
+        data={CLASS_NAME_DATA_FIELD: np.array(classes)},
+    )
+
+    predictions[DETECTION_ID_KEY] = np.array(detections_id)
+    predictions[PREDICTION_TYPE_KEY] = np.array(["ocr"] * len(predictions))
+    image_height, image_width = image.numpy_image.shape[:2]
+    predictions[IMAGE_DIMENSIONS_KEY] = np.array(
+        [[image_height, image_width]] * len(predictions)
+    )
+
+    predictions = attach_parents_coordinates_to_sv_detections(
+        detections=predictions,
+        image=image,
+    )
+
+    return {
+        "text": text,
+        "language": language,
+        "predictions": predictions,
+    }
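The routing rule added to `run()` is a plain prefix check on the key string. A quick illustration of which values take which path (the Google-style key below is made up):

for key in ("rf_key:account", "rf_key:user:1234", "AIzaSy-made-up-key"):
    proxied = key.startswith(("rf_key:account", "rf_key:user:"))
    print(key, "->", "Roboflow proxy" if proxied else "direct Google Vision request")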

tests/inference/hosted_platform_tests/workflows_examples/test_workflow_with_google_ocr.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@
 }
 
 
-@pytest.mark.skipif(GOOGLE_VISION_API_KEY is None, reason="No OpenAI API key provided")
+@pytest.mark.skipif(GOOGLE_VISION_API_KEY is None, reason="No Google API key provided")
 @pytest.mark.flaky(retries=4, delay=1)
 def test_workflow_with_google_api_ocr(
     object_detection_service_url: str,
