Skip to content

Commit 230ec7a

Browse files
Merge branch 'main' into fix/usage-collector-model-id
2 parents 2ddc72a + c2ce156 commit 230ec7a

File tree

18 files changed

+499
-98
lines changed

18 files changed

+499
-98
lines changed

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,11 @@ def send_chunked_data(
162162
f"Sending response for frame {frame_id}: {total_chunks} chunk(s), {len(payload_bytes)} bytes"
163163
)
164164

165+
view = memoryview(payload_bytes)
165166
for chunk_index in range(total_chunks):
166167
start = chunk_index * chunk_size
167168
end = min(start + chunk_size, len(payload_bytes))
168-
chunk_data = payload_bytes[start:end]
169+
chunk_data = view[start:end]
169170

170171
message = create_chunked_binary_message(
171172
frame_id, chunk_index, total_chunks, chunk_data
@@ -305,7 +306,9 @@ async def _send_data_output(
305306

306307
if self._data_mode == DataOutputMode.NONE:
307308
# Even empty responses use binary protocol
308-
json_bytes = json.dumps(webrtc_output.model_dump()).encode("utf-8")
309+
json_bytes = await asyncio.to_thread(
310+
lambda: json.dumps(webrtc_output.model_dump()).encode("utf-8")
311+
)
309312
send_chunked_data(self.data_channel, self._received_frames, json_bytes)
310313
return
311314

@@ -371,15 +374,22 @@ async def _handle_data_channel_frame(self, message: bytes) -> None:
371374
f"Received frame {frame_id}: {total_chunks} chunk(s), {len(jpeg_bytes)} bytes JPEG"
372375
)
373376

374-
nparr = np.frombuffer(jpeg_bytes, np.uint8)
375-
np_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
377+
def _decode_to_frame(jpeg_bytes: bytes) -> VideoFrame:
378+
nparr = np.frombuffer(jpeg_bytes, np.uint8)
379+
np_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
380+
381+
if np_image is None:
382+
raise ValueError("cv2.imdecode returned None")
383+
384+
return VideoFrame.from_ndarray(np_image, format="bgr24")
376385

377-
if np_image is None:
378-
logger.error(f"Failed to decode JPEG for frame {frame_id}")
386+
try:
387+
video_frame = await asyncio.to_thread(_decode_to_frame, jpeg_bytes)
388+
except Exception as e:
389+
logger.error(f"Failed to decode JPEG for frame {frame_id}: {e}")
379390
return
380391

381-
video_frame = VideoFrame.from_ndarray(np_image, format="bgr24")
382-
await self._data_frame_queue.put((frame_id, video_frame))
392+
self._data_frame_queue.put_nowait((frame_id, video_frame))
383393

384394
if frame_id % 100 == 1:
385395
logger.info(f"Queued frame {frame_id}")
@@ -440,8 +450,10 @@ async def process_frames_data_only(self):
440450
)
441451

442452
# Send data via data channel
443-
await self._send_data_output(
444-
workflow_output, frame_timestamp, frame, errors
453+
asyncio.create_task(
454+
self._send_data_output(
455+
workflow_output, frame_timestamp, frame, errors
456+
)
445457
)
446458

447459
except asyncio.CancelledError:

inference_experimental/inference_exp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
MultiLabelClassificationPrediction,
1313
)
1414
from inference_exp.models.base.depth_estimation import DepthEstimationModel
15-
from inference_exp.models.base.documents_parsing import DocumentParsingModel
15+
from inference_exp.models.base.documents_parsing import StructuredOCRModel
1616
from inference_exp.models.base.embeddings import TextImageEmbeddingModel
1717
from inference_exp.models.base.instance_segmentation import (
1818
InstanceDetections,

inference_experimental/inference_exp/models/auto_loaders/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
MultiLabelClassificationModel,
5858
)
5959
from inference_exp.models.base.depth_estimation import DepthEstimationModel
60-
from inference_exp.models.base.documents_parsing import DocumentParsingModel
60+
from inference_exp.models.base.documents_parsing import StructuredOCRModel
6161
from inference_exp.models.base.embeddings import TextImageEmbeddingModel
6262
from inference_exp.models.base.instance_segmentation import InstanceSegmentationModel
6363
from inference_exp.models.base.keypoints_detection import KeyPointsDetectionModel
@@ -79,7 +79,7 @@
7979
ClassificationModel,
8080
MultiLabelClassificationModel,
8181
DepthEstimationModel,
82-
DocumentParsingModel,
82+
StructuredOCRModel,
8383
TextImageEmbeddingModel,
8484
InstanceSegmentationModel,
8585
KeyPointsDetectionModel,

inference_experimental/inference_exp/models/auto_loaders/models_registry.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CLASSIFICATION_TASK = "classification"
1616
MULTI_LABEL_CLASSIFICATION_TASK = "multi-label-classification"
1717
DEPTH_ESTIMATION_TASK = "depth-estimation"
18+
STRUCTURED_OCR_TASK = "structured-ocr"
1819

1920

2021
@dataclass(frozen=True)
@@ -356,8 +357,11 @@ class RegistryEntry:
356357
),
357358
("depth-anything-v2", DEPTH_ESTIMATION_TASK, BackendType.HF): LazyClass(
358359
module_name="inference_exp.models.depth_anything_v2.depth_anything_v2_hf",
359-
class_name="DepthAnythingV2HF"
360-
)
360+
class_name="DepthAnythingV2HF",
361+
),
362+
("doctr", STRUCTURED_OCR_TASK, BackendType.TORCH): LazyClass(
363+
module_name="inference_exp.models.doctr.doctr_torch", class_name="DocTR"
364+
),
361365
}
362366

363367

inference_experimental/inference_exp/models/base/documents_parsing.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,13 @@
1111
)
1212

1313

14-
class DocumentParsingModel(
14+
class StructuredOCRModel(
1515
ABC, Generic[PreprocessedInputs, PreprocessingMetadata, RawPrediction]
1616
):
1717

1818
@classmethod
1919
@abstractmethod
20-
def from_pretrained(
21-
cls, model_name_or_path: str, **kwargs
22-
) -> "DocumentParsingModel":
20+
def from_pretrained(cls, model_name_or_path: str, **kwargs) -> "StructuredOCRModel":
2321
pass
2422

2523
@property

inference_experimental/inference_exp/models/depth_anything_v2/depth_anything_v2_hf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ def from_pretrained(
2828
local_files_only=local_files_only,
2929
).to(device)
3030
processor = AutoImageProcessor.from_pretrained(
31-
model_name_or_path,
32-
local_files_only=local_files_only,
33-
use_fast=True
31+
model_name_or_path, local_files_only=local_files_only, use_fast=True
3432
)
3533
return cls(model=model, processor=processor, device=device)
3634

inference_experimental/inference_exp/models/doctr/doctr_torch.py

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,98 @@
1-
import os
21
from dataclasses import dataclass
32
from typing import Callable, List, Optional, Tuple, Union
43

54
import numpy as np
65
import torch
76
from doctr.io import Document
8-
from doctr.models import ocr_predictor
7+
from doctr.models import detection_predictor, ocr_predictor, recognition_predictor
98
from inference_exp import Detections
109
from inference_exp.configuration import DEFAULT_DEVICE
1110
from inference_exp.entities import ColorFormat, ImageDimensions
1211
from inference_exp.errors import CorruptedModelPackageError, ModelRuntimeError
13-
from inference_exp.models.base.documents_parsing import DocumentParsingModel
12+
from inference_exp.models.base.documents_parsing import StructuredOCRModel
1413
from inference_exp.models.common.model_packages import get_model_package_contents
1514
from inference_exp.utils.file_system import read_json
1615

17-
WEIGHTS_NAMES_MAPPING = {
18-
"db_resnet50": "db_resnet50-79bd7d70.pt",
19-
"db_resnet34": "db_resnet34-cb6aed9e.pt",
20-
"db_mobilenet_v3_large": "db_mobilenet_v3_large-21748dd0.pt",
21-
"crnn_vgg16_bn": "crnn_vgg16_bn-9762b0b0.pt",
22-
"crnn_mobilenet_v3_small": "crnn_mobilenet_v3_small_pt-3b919a02.pt",
23-
"crnn_mobilenet_v3_large": "crnn_mobilenet_v3_large_pt-f5259ec2.pt",
16+
SUPPORTED_DETECTION_MODELS = {
17+
"fast_base",
18+
"fast_small",
19+
"fast_tiny",
20+
"db_resnet50",
21+
"db_resnet34",
22+
"db_mobilenet_v3_large",
23+
"linknet_resnet18",
24+
"linknet_resnet34",
25+
"linknet_resnet50",
26+
}
27+
SUPPORTED_RECOGNITION_MODELS = {
28+
"crnn_vgg16_bn",
29+
"crnn_mobilenet_v3_small",
30+
"crnn_mobilenet_v3_large",
31+
"master",
32+
"sar_resnet31",
33+
"vitstr_small",
34+
"vitstr_base",
35+
"parseq",
2436
}
2537

2638

27-
class DocTR(DocumentParsingModel[List[np.ndarray], ImageDimensions, Document]):
39+
class DocTR(StructuredOCRModel[List[np.ndarray], ImageDimensions, Document]):
2840

2941
@classmethod
3042
def from_pretrained(
3143
cls,
3244
model_name_or_path: str,
3345
device: torch.device = DEFAULT_DEVICE,
46+
assume_straight_pages: bool = True,
47+
preserve_aspect_ratio: bool = True,
48+
detection_max_batch_size: int = 2,
49+
recognition_max_batch_size: int = 128,
3450
**kwargs,
35-
) -> "DocumentParsingModel":
36-
os.environ["DOCTR_CACHE_DIR"] = model_name_or_path
51+
) -> "StructuredOCRModel":
3752
model_package_content = get_model_package_contents(
3853
model_package_dir=model_name_or_path,
39-
elements=["doctr_det", "doctr_rec", "config.json"],
54+
elements=["detection_weights.pt", "recognition_weights.pt", "config.json"],
4055
)
4156
config = parse_model_config(config_path=model_package_content["config.json"])
42-
os.makedirs(f"{model_name_or_path}/doctr_det/models/", exist_ok=True)
43-
os.makedirs(f"{model_name_or_path}/doctr_rec/models/", exist_ok=True)
44-
det_model_source_path = os.path.join(
45-
model_name_or_path, "doctr_det", config.det_model, "model.pt"
46-
)
47-
rec_model_source_path = os.path.join(
48-
model_name_or_path, "doctr_rec", config.rec_model, "model.pt"
49-
)
50-
if not os.path.exists(det_model_source_path):
51-
raise CorruptedModelPackageError(
52-
message="Could not initialize DocTR model - could not find detection model weights.",
53-
help_url="https://todo",
54-
)
55-
if not os.path.exists(rec_model_source_path):
56-
raise CorruptedModelPackageError(
57-
message="Could not initialize DocTR model - could not find recognition model weights.",
58-
help_url="https://todo",
59-
)
60-
if config.det_model not in WEIGHTS_NAMES_MAPPING:
57+
if config.det_model not in SUPPORTED_DETECTION_MODELS:
6158
raise CorruptedModelPackageError(
6259
message=f"{config.det_model} model denoted in configuration not supported as DocTR detection model.",
6360
help_url="https://todo",
6461
)
65-
if config.rec_model not in WEIGHTS_NAMES_MAPPING:
62+
if config.rec_model not in SUPPORTED_RECOGNITION_MODELS:
6663
raise CorruptedModelPackageError(
67-
message=f"{config.det_model} model denoted in configuration not supported as DocTR recognition model.",
64+
message=f"{config.rec_model} model denoted in configuration not supported as DocTR recognition model.",
6865
help_url="https://todo",
6966
)
70-
det_model_target_path = os.path.join(
71-
model_name_or_path, "models", WEIGHTS_NAMES_MAPPING[config.det_model]
67+
det_model = detection_predictor(
68+
arch=config.det_model,
69+
pretrained=False,
70+
assume_straight_pages=assume_straight_pages,
71+
preserve_aspect_ratio=preserve_aspect_ratio,
72+
batch_size=detection_max_batch_size,
73+
)
74+
det_model.model.to(device)
75+
detector_weights = torch.load(
76+
model_package_content["detection_weights.pt"],
77+
weights_only=True,
78+
map_location=device,
79+
)
80+
det_model.model.load_state_dict(detector_weights)
81+
rec_model = recognition_predictor(
82+
arch=config.rec_model,
83+
pretrained=False,
84+
batch_size=recognition_max_batch_size,
7285
)
73-
rec_model_target_path = os.path.join(
74-
model_name_or_path, "models", WEIGHTS_NAMES_MAPPING[config.rec_model]
86+
rec_model.model.to(device)
87+
rec_weights = torch.load(
88+
model_package_content["recognition_weights.pt"],
89+
weights_only=True,
90+
map_location=device,
7591
)
76-
if os.path.exists(det_model_target_path):
77-
os.remove(det_model_target_path)
78-
os.symlink(det_model_source_path, det_model_target_path)
79-
if os.path.exists(rec_model_target_path):
80-
os.remove(rec_model_target_path)
81-
os.symlink(rec_model_source_path, rec_model_target_path)
92+
rec_model.model.load_state_dict(rec_weights)
8293
model = ocr_predictor(
83-
det_arch=config.det_model,
84-
reco_arch=config.rec_model,
85-
pretrained=True,
94+
det_arch=det_model.model,
95+
reco_arch=rec_model.model,
8696
).to(device=device)
8797
return cls(model=model, device=device)
8898

inference_experimental/inference_exp/models/moondream2/moondream2_hf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def from_pretrained(
3737
if torch.mps.is_available():
3838
raise ModelRuntimeError(
3939
message=f"This model cannot run on Apple device with MPS unit - original implementation contains bug "
40-
f"preventing proper allocation of tensors which causes runtime error. Run this model on the "
41-
f"machine with Nvidia GPU or x86 CPU.",
40+
f"preventing proper allocation of tensors which causes runtime error. Run this model on the "
41+
f"machine with Nvidia GPU or x86 CPU.",
4242
help_url="https://todo",
4343
)
4444
model_package_content = get_model_package_contents(

inference_experimental/inference_exp/models/rfdetr/rfdetr_instance_segmentation_pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def from_pretrained(
124124
model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
125125
checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
126126
model_config.num_classes = checkpoint_num_classes - 1
127-
model_config.resolution = inference_config.network_input.training_input_size.height
127+
model_config.resolution = (
128+
inference_config.network_input.training_input_size.height
129+
)
128130
model = build_model(config=model_config)
129131
model.load_state_dict(weights_dict)
130132
model = model.eval().to(device)

inference_experimental/inference_exp/models/rfdetr/rfdetr_object_detection_pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ def from_pretrained(
130130
model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
131131
checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
132132
model_config.num_classes = checkpoint_num_classes - 1
133-
model_config.resolution = inference_config.network_input.training_input_size.height
133+
model_config.resolution = (
134+
inference_config.network_input.training_input_size.height
135+
)
134136
model = build_model(config=model_config)
135137
model.load_state_dict(weights_dict)
136138
model = model.eval().to(device)

0 commit comments

Comments
 (0)