Skip to content

Commit 23e3094

Browse files
authored
Merge branch 'main' into feat/openai-block-v4
2 parents e8a51ae + d7abe49 commit 23e3094

File tree

9 files changed

+566
-2
lines changed

9 files changed

+566
-2
lines changed

inference_experimental/inference_exp/models/auto_loaders/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,9 @@ def create_symlinks_to_shared_blobs(
653653
link_name = os.path.join(model_dir, file_handle)
654654
target_path = shared_files_mapping[file_handle]
655655
result[file_handle] = link_name
656-
if os.path.exists(link_name):
656+
if os.path.exists(link_name) and (
657+
not os.path.islink(link_name) or os.path.realpath(link_name) == target_path
658+
):
657659
continue
658660
handle_symlink_creation(
659661
target_path=target_path,

inference_experimental/inference_exp/models/auto_loaders/models_registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,10 @@ class RegistryEntry:
362362
("doctr", STRUCTURED_OCR_TASK, BackendType.TORCH): LazyClass(
363363
module_name="inference_exp.models.doctr.doctr_torch", class_name="DocTR"
364364
),
365+
("easy-ocr", STRUCTURED_OCR_TASK, BackendType.TORCH): LazyClass(
366+
module_name="inference_exp.models.easy_ocr.easy_ocr_torch",
367+
class_name="EasyOCRTorch",
368+
),
365369
}
366370

367371

inference_experimental/inference_exp/models/easy_ocr/__init__.py

Whitespace-only changes.
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
from typing import List, Optional, Tuple, Union
2+
3+
import easyocr
4+
import numpy as np
5+
import torch
6+
from inference_exp import Detections, StructuredOCRModel
7+
from inference_exp.configuration import DEFAULT_DEVICE
8+
from inference_exp.entities import ColorFormat, ImageDimensions
9+
from inference_exp.errors import CorruptedModelPackageError, ModelRuntimeError
10+
from inference_exp.models.common.model_packages import get_model_package_contents
11+
from inference_exp.utils.file_system import read_json
12+
from pydantic import BaseModel
13+
14+
Point = Tuple[int, int]
15+
Coordinates = Tuple[Point, Point, Point, Point]
16+
DetectedText = str
17+
Confidence = float
18+
EasyOCRRawPrediction = Tuple[Coordinates, DetectedText, Confidence]
19+
20+
21+
RECOGNIZED_DETECTORS = {"craft", "dbnet18", "dbnet50"}
22+
23+
24+
class EasyOcrConfig(BaseModel):
25+
lang_list: List[str]
26+
detector_model_file_name: str
27+
recognition_model_file_name: str
28+
detect_network: str
29+
recognition_network: str
30+
31+
32+
class EasyOCRTorch(
33+
StructuredOCRModel[List[np.ndarray], ImageDimensions, EasyOCRRawPrediction]
34+
):
35+
36+
@classmethod
37+
def from_pretrained(
38+
cls,
39+
model_name_or_path: str,
40+
device: torch.device = DEFAULT_DEVICE,
41+
**kwargs,
42+
) -> "StructuredOCRModel":
43+
package_contents = get_model_package_contents(
44+
model_package_dir=model_name_or_path, elements=["easy-ocr-config.json"]
45+
)
46+
config = parse_easy_ocr_config(
47+
config_path=package_contents["easy-ocr-config.json"]
48+
)
49+
device_string = device.type
50+
if device.type == "cuda" and device.index:
51+
device_string = f"{device_string}:{device.index}"
52+
try:
53+
model = easyocr.Reader(
54+
config.lang_list,
55+
download_enabled=False,
56+
model_storage_directory=model_name_or_path,
57+
user_network_directory=model_name_or_path,
58+
detect_network=config.detect_network,
59+
recog_network=config.recognition_network,
60+
detector=True,
61+
recognizer=True,
62+
gpu=device_string,
63+
)
64+
except Exception as error:
65+
raise CorruptedModelPackageError(
66+
message=f"EasyOCR model package is broken - could not parse model config file. Error: {error}"
67+
f"If you attempt to run `inference-exp` locally - inspect the contents of local directory to check "
68+
f"model package - config file is corrupted. If you run the model on Roboflow platform - "
69+
f"contact us.",
70+
help_url="https://todo",
71+
) from error
72+
return cls(model=model, device=device)
73+
74+
def __init__(
75+
self,
76+
model: easyocr.Reader,
77+
device: torch.device,
78+
):
79+
self._model = model
80+
self._device = device
81+
82+
@property
83+
def class_names(self) -> List[str]:
84+
return ["text-region"]
85+
86+
def pre_process(
87+
self,
88+
images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
89+
input_color_format: Optional[ColorFormat] = None,
90+
**kwargs,
91+
) -> Tuple[List[np.ndarray], List[ImageDimensions]]:
92+
if isinstance(images, np.ndarray):
93+
input_color_format = input_color_format or "bgr"
94+
if input_color_format != "bgr":
95+
images = images[:, :, ::-1]
96+
h, w = images.shape[:2]
97+
return [images], [ImageDimensions(height=h, width=w)]
98+
if isinstance(images, torch.Tensor):
99+
input_color_format = input_color_format or "rgb"
100+
if len(images.shape) == 3:
101+
images = torch.unsqueeze(images, dim=0)
102+
if input_color_format != "bgr":
103+
images = images[:, [2, 1, 0], :, :]
104+
result = []
105+
dimensions = []
106+
for image in images:
107+
np_image = image.permute(1, 2, 0).cpu().numpy()
108+
result.append(np_image)
109+
dimensions.append(
110+
ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
111+
)
112+
return result, dimensions
113+
if not isinstance(images, list):
114+
raise ModelRuntimeError(
115+
message="Pre-processing supports only np.array or torch.Tensor or list of above.",
116+
help_url="https://todo",
117+
)
118+
if not len(images):
119+
raise ModelRuntimeError(
120+
message="Detected empty input to the model", help_url="https://todo"
121+
)
122+
if isinstance(images[0], np.ndarray):
123+
input_color_format = input_color_format or "bgr"
124+
if input_color_format != "bgr":
125+
images = [i[:, :, ::-1] for i in images]
126+
dimensions = [
127+
ImageDimensions(height=i.shape[0], width=i.shape[1]) for i in images
128+
]
129+
return images, dimensions
130+
if isinstance(images[0], torch.Tensor):
131+
result = []
132+
dimensions = []
133+
input_color_format = input_color_format or "rgb"
134+
for image in images:
135+
if input_color_format != "bgr":
136+
image = image[[2, 1, 0], :, :]
137+
np_image = image.permute(1, 2, 0).cpu().numpy()
138+
result.append(np_image)
139+
dimensions.append(
140+
ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
141+
)
142+
return result, dimensions
143+
raise ModelRuntimeError(
144+
message=f"Detected unknown input batch element: {type(images[0])}",
145+
help_url="https://todo",
146+
)
147+
148+
def forward(
149+
self, pre_processed_images: List[np.ndarray], **kwargs
150+
) -> List[EasyOCRRawPrediction]:
151+
all_results = []
152+
for image in pre_processed_images:
153+
image_results_raw = self._model.readtext(image)
154+
image_results_parsed = [
155+
(
156+
[
157+
[x.item() if not isinstance(x, (int, float)) else x for x in c]
158+
for c in res[0]
159+
],
160+
res[1],
161+
res[2].item() if not isinstance(res[2], (int, float)) else res[2],
162+
)
163+
for res in image_results_raw
164+
]
165+
all_results.append(image_results_parsed)
166+
return all_results
167+
168+
def post_process(
169+
self,
170+
model_results: List[EasyOCRRawPrediction],
171+
pre_processing_meta: List[ImageDimensions],
172+
confidence_threshold: float = 0.3,
173+
text_regions_separator: str = " ",
174+
**kwargs,
175+
) -> Tuple[List[str], List[Detections]]:
176+
rendered_texts, all_detections = [], []
177+
for single_image_result, original_dimensions in zip(
178+
model_results, pre_processing_meta
179+
):
180+
whole_image_text = []
181+
xyxy = []
182+
confidence = []
183+
class_id = []
184+
for box, text, text_confidence in single_image_result:
185+
if text_confidence < confidence_threshold:
186+
continue
187+
whole_image_text.append(text)
188+
min_x = min(p[0] for p in box)
189+
min_y = min(p[1] for p in box)
190+
max_x = max(p[0] for p in box)
191+
max_y = max(p[1] for p in box)
192+
box_xyxy = [min_x, min_y, max_x, max_y]
193+
xyxy.append(box_xyxy)
194+
confidence.append(float(text_confidence))
195+
class_id.append(0)
196+
while_image_text_joined = text_regions_separator.join(whole_image_text)
197+
rendered_texts.append(while_image_text_joined)
198+
data = [{"text": text} for text in whole_image_text]
199+
all_detections.append(
200+
Detections(
201+
xyxy=torch.tensor(xyxy, device=self._device),
202+
class_id=torch.tensor(class_id, device=self._device),
203+
confidence=torch.tensor(confidence, device=self._device),
204+
bboxes_metadata=data,
205+
)
206+
)
207+
return rendered_texts, all_detections
208+
209+
210+
def parse_easy_ocr_config(config_path: str) -> EasyOcrConfig:
211+
try:
212+
raw_config = read_json(config_path)
213+
return EasyOcrConfig.model_validate(raw_config)
214+
except Exception as error:
215+
raise CorruptedModelPackageError(
216+
message=f"EasyOCR model package is broken - could not parse model config file. Error: {error}"
217+
f"If you attempt to run `inference-exp` locally - inspect the contents of local directory to check "
218+
f"model package - config file is corrupted. If you run the model on Roboflow platform - "
219+
f"contact us.",
220+
help_url="https://todo",
221+
) from error

inference_experimental/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ dependencies = [
2828
"filelock>=3.12.0,<4.0.0",
2929
"rich>=14.1.0,<15.0.0",
3030
"segmentation-models-pytorch>=0.5.0,<1.0.0",
31-
"scikit-image>=0.24.0,<0.26.0"
31+
"scikit-image>=0.24.0,<0.26.0",
32+
"easyocr~=1.7.2",
3233
]
3334

3435
[project.optional-dependencies]
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import pytest
3+
from inference_exp import AutoModel, Detections
4+
5+
6+
@pytest.mark.e2e_model_inference
7+
def test_easyocr_english(
8+
ocr_test_image_numpy: np.ndarray, roboflow_api_key: str
9+
) -> None:
10+
# given
11+
model = AutoModel.from_pretrained("easy-ocr-english")
12+
13+
# when
14+
result = model(ocr_test_image_numpy)
15+
16+
# then
17+
assert len(result) == 2
18+
assert result[0][0].startswith("This is a test image for OCR")
19+
assert isinstance(result[1][0], Detections)
20+
21+
22+
@pytest.mark.e2e_model_inference
23+
def test_easyocr_latin(ocr_test_image_numpy: np.ndarray, roboflow_api_key: str) -> None:
24+
# given
25+
model = AutoModel.from_pretrained("easy-ocr-latin")
26+
27+
# when
28+
result = model(ocr_test_image_numpy)
29+
30+
# then
31+
assert len(result) == 2
32+
assert result[0][0].startswith("This is a test image for OCR")
33+
assert isinstance(result[1][0], Detections)
34+
35+
36+
@pytest.mark.e2e_model_inference
37+
def test_easyocr_japanese(
38+
ocr_test_image_numpy: np.ndarray, roboflow_api_key: str
39+
) -> None:
40+
# given
41+
model = AutoModel.from_pretrained("easy-ocr-japanese")
42+
43+
# when
44+
result = model(ocr_test_image_numpy)
45+
46+
# then
47+
assert len(result) == 2
48+
assert isinstance(result[1][0], Detections)

inference_experimental/tests/integration_tests/models/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@
142142

143143
DEPTH_ANYTHING_V2_SMALL_PACKAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/depth-anything-v2.zip"
144144
DOCTR_PACKAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/doctr-dbnet-rn50-crnn-vgg16.zip"
145+
EASY_OCR_PACKAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/easy-ocr-english.zip"
145146

146147

147148
@pytest.fixture(scope="module")
@@ -1142,3 +1143,10 @@ def doctr_package() -> str:
11421143
return download_model_package(
11431144
model_package_zip_url=DOCTR_PACKAGE_URL, package_name="doctr"
11441145
)
1146+
1147+
1148+
@pytest.fixture(scope="module")
1149+
def easy_ocr_package() -> str:
1150+
return download_model_package(
1151+
model_package_zip_url=EASY_OCR_PACKAGE_URL, package_name="easy-ocr"
1152+
)

0 commit comments

Comments
 (0)