
Commit 27e9a15

Merge pull request #1657 from roboflow/revert-1656-revert-1642-exp-rfdetr-2
Revert "Revert "USE_INFERENCE_EXP_MODELS""

2 parents b16f4fc + 5c22472

File tree

20 files changed: +18810 −1510 lines changed


inference/core/env.py

Lines changed: 3 additions & 0 deletions
@@ -196,6 +196,9 @@
     os.getenv("CORE_MODEL_YOLO_WORLD_ENABLED", True)
 )

+# Enable the experimental RF-DETR backend (inference_exp) rollout; defaults to False
+USE_INFERENCE_EXP_MODELS = str2bool(os.getenv("USE_INFERENCE_EXP_MODELS", "False"))
+
 # ID of host device, default is None
 DEVICE_ID = os.getenv("DEVICE_ID", None)
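Since env.py evaluates the variable at import time, a quick way to exercise the new path in a local session is to set it before anything from inference is imported. A minimal sketch, assuming the flag name from this diff; the model ID is a placeholder:

    import os

    # env.py reads the flag at import time, so set it before importing inference
    os.environ["USE_INFERENCE_EXP_MODELS"] = "True"

    from inference.models.utils import get_model

    model = get_model("my-project/3")  # hypothetical model ID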

inference/core/models/base.py

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ def infer_from_request(
         responses = self.infer(**request.dict(), return_image_dims=False)
         for response in responses:
             response.time = perf_counter() - t1
+            logger.debug(f"model infer time: {response.time * 1000.0} ms")
         if request.id:
             response.inference_id = request.id

inference/core/models/exp_adapter.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
from threading import Lock
from time import perf_counter
from typing import Any, Generic, List, Optional, Tuple, Union

import numpy as np
from inference_exp.models.base.object_detection import Detections, ObjectDetectionModel
from inference_exp.models.base.types import (
    PreprocessedInputs,
    PreprocessingMetadata,
    RawPrediction,
)

from inference.core.entities.responses.inference import (
    InferenceResponseImage,
    ObjectDetectionInferenceResponse,
    ObjectDetectionPrediction,
)
from inference.core.env import API_KEY
from inference.core.logger import logger
from inference.core.models.base import Model
from inference.core.utils.image_utils import load_image_rgb
from inference.models.aliases import resolve_roboflow_model_alias


class InferenceExpObjectDetectionModelAdapter(Model):
    def __init__(self, model_id: str, api_key: str = None, **kwargs):
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY
        model_id = resolve_roboflow_model_alias(model_id=model_id)

        self.task_type = "object-detection"

        # Lazy import to avoid a hard dependency when the flag is disabled
        from inference_exp import AutoModel  # type: ignore

        self._exp_model: ObjectDetectionModel = AutoModel.from_pretrained(
            model_id_or_path=model_id, api_key=self.api_key
        )
        if hasattr(self._exp_model, "optimize_for_inference"):
            self._exp_model.optimize_for_inference()

        self.class_names = list(self._exp_model.class_names)

    def map_inference_kwargs(self, kwargs: dict) -> dict:
        return kwargs

    def preprocess(self, image: Any, **kwargs):
        is_batch = isinstance(image, list)
        images = image if is_batch else [image]
        np_images: List[np.ndarray] = [
            load_image_rgb(
                v,
                disable_preproc_auto_orient=kwargs.get(
                    "disable_preproc_auto_orient", False
                ),
            )
            for v in images
        ]
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._exp_model.pre_process(np_images, **mapped_kwargs)

    def predict(self, img_in, **kwargs):
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._exp_model.forward(img_in, **mapped_kwargs)

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessingMetadata,
        **kwargs,
    ) -> List[Detections]:
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        detections_list = self._exp_model.post_process(
            predictions, preprocess_return_metadata, **mapped_kwargs
        )

        responses: List[ObjectDetectionInferenceResponse] = []
        for preproc_metadata, det in zip(preprocess_return_metadata, detections_list):
            H = preproc_metadata.original_size.height
            W = preproc_metadata.original_size.width

            xyxy = det.xyxy.detach().cpu().numpy()
            confs = det.confidence.detach().cpu().numpy()
            class_ids = det.class_id.detach().cpu().numpy()

            predictions: List[ObjectDetectionPrediction] = []

            for (x1, y1, x2, y2), conf, class_id in zip(xyxy, confs, class_ids):
                cx = (float(x1) + float(x2)) / 2.0
                cy = (float(y1) + float(y2)) / 2.0
                w = float(x2) - float(x1)
                h = float(y2) - float(y1)
                class_id_int = int(class_id)
                class_name = (
                    self.class_names[class_id_int]
                    if 0 <= class_id_int < len(self.class_names)
                    else str(class_id_int)
                )
                predictions.append(
                    ObjectDetectionPrediction(
                        x=cx,
                        y=cy,
                        width=w,
                        height=h,
                        confidence=float(conf),
                        **{"class": class_name},
                        class_id=class_id_int,
                    )
                )

            responses.append(
                ObjectDetectionInferenceResponse(
                    predictions=predictions,
                    image=InferenceResponseImage(width=W, height=H),
                )
            )

        return responses

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Clears any cache if necessary. TODO: implement deletion of the experimental model's cache.

        Args:
            delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
        """
        pass
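A minimal sketch of how the adapter's three stages chain together, assuming inference_exp is installed and a valid API key is available; the model ID and image path are placeholders, and the (inputs, metadata) return shape of preprocess is an assumption that mirrors the base Model pipeline rather than anything shown above:

    from inference.core.models.exp_adapter import InferenceExpObjectDetectionModelAdapter

    adapter = InferenceExpObjectDetectionModelAdapter(
        model_id="my-project/3", api_key="..."  # placeholders
    )
    img_in, preproc_meta = adapter.preprocess("image.jpg")  # assumed (inputs, metadata) pair
    raw_predictions = adapter.predict(img_in)
    responses = adapter.postprocess(raw_predictions, preproc_meta)
    print(responses[0].predictions)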

inference/models/aliases.py

Lines changed: 7 additions & 3 deletions
@@ -17,11 +17,15 @@
     "paligemma-3b-ft-docvqa-448": "paligemma-pretrains/18",
     "paligemma-3b-ft-ocrvqa-448": "paligemma-pretrains/19",
 }
+# FLORENCE_ALIASES = {
+#     "florence-2-base": "florence-pretrains/1",
+#     "florence-2-large": "florence-pretrains/2",
+# }
+# since transformers 0.53.3, a newer version of the Florence-2 weights is needed
 FLORENCE_ALIASES = {
-    "florence-2-base": "florence-pretrains/1",
-    "florence-2-large": "florence-pretrains/2",
+    "florence-2-base": "florence-pretrains/3",
+    "florence-2-large": "florence-pretrains/4",
 }
-
 QWEN_ALIASES = {
     "qwen25-vl-7b": "qwen-pretrains/1",
 }

inference/models/florence2/florence2.py

Lines changed: 15 additions & 5 deletions
@@ -71,6 +71,7 @@ def initialize_model(self, **kwargs):
         lora_config = LoraConfig.from_pretrained(self.cache_dir, device_map=DEVICE)
         model_id = lora_config.base_model_name_or_path
         revision = lora_config.revision
+        original_revision_pre_mapping = revision
         if revision is not None:
             try:
                 self.dtype = getattr(torch, revision)
@@ -135,11 +136,19 @@ def initialize_model(self, **kwargs):
                     adapter_missing_keys.append(key)
                 load_result.missing_keys.clear()
                 load_result.missing_keys.extend(adapter_missing_keys)
-                if len(load_result.missing_keys) > 0:
-                    raise RuntimeError(
-                        "Could not load LoRA weights for the model - found missing checkpoint keys "
-                        f"({len(load_result.missing_keys)}): {load_result.missing_keys}",
-                    )
+                if original_revision_pre_mapping == "refs/pr/6":
+                    if len(load_result.missing_keys) > 2:
+                        raise RuntimeError(
+                            "Could not load LoRA weights for the model - found missing checkpoint keys "
+                            f"({len(load_result.missing_keys)}): {load_result.missing_keys}",
+                        )
+                else:
+                    if len(load_result.missing_keys) > 0:
+                        raise RuntimeError(
+                            "Could not load LoRA weights for the model - found missing checkpoint keys "
+                            f"({len(load_result.missing_keys)}): {load_result.missing_keys}",
+                        )

                 self.model = model
             except ImportError:
@@ -166,6 +175,7 @@ def get_lora_base_from_roboflow(self, model_id, revision):
         )

         revision_mapping = {
+            ("microsoft/Florence-2-base-ft", "refs/pr/6"): "refs/pr/29-converted",
             ("microsoft/Florence-2-base-ft", "refs/pr/22"): "refs/pr/29-converted",
             ("microsoft/Florence-2-large-ft", "refs/pr/20"): "refs/pr/38-converted",
         }
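The net effect of the new mapping entry can be sketched as follows; the dictionary is copied from the hunk above, while the lookup-with-fallback is an assumption about the surrounding code, which this diff does not show:

    # Entries copied from the diff; the .get fallback is assumed
    revision_mapping = {
        ("microsoft/Florence-2-base-ft", "refs/pr/6"): "refs/pr/29-converted",
        ("microsoft/Florence-2-base-ft", "refs/pr/22"): "refs/pr/29-converted",
        ("microsoft/Florence-2-large-ft", "refs/pr/20"): "refs/pr/38-converted",
    }
    model_id, revision = "microsoft/Florence-2-base-ft", "refs/pr/6"
    revision = revision_mapping.get((model_id, revision), revision)
    print(revision)  # refs/pr/29-converted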

inference/models/florence2/utils.py

Lines changed: 12 additions & 0 deletions
@@ -15,10 +15,15 @@ def import_class_from_file(file_path, class_name, alias_name=None):

     sys.path.insert(0, parent_dir)

+    previous_module = sys.modules.get(module_name)
+    injected = False
     try:
         spec = importlib.util.spec_from_file_location(module_name, file_path)
         module = importlib.util.module_from_spec(spec)

+        sys.modules[module_name] = module
+        injected = True
+
         # Manually set the __package__ attribute to the parent package
         module.__package__ = os.path.basename(module_dir)

@@ -27,5 +32,12 @@ def import_class_from_file(file_path, class_name, alias_name=None):
         if alias_name:
             globals()[alias_name] = cls
         return cls
+    except Exception:
+        if injected:
+            if previous_module is not None:
+                sys.modules[module_name] = previous_module
+            else:
+                sys.modules.pop(module_name, None)
+        raise
     finally:
         sys.path.pop(0)
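The register-before-exec, roll-back-on-failure pattern added here is the standard importlib recipe; a self-contained sketch, independent of this repository:

    import importlib.util
    import sys


    def load_module_from_path(module_name: str, file_path: str):
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        # Register before exec_module so imports issued inside the module resolve it
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except Exception:
            # Undo the registration on failure, as the diff above does
            sys.modules.pop(module_name, None)
            raise
        return module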
inference/models/rfdetr/rfdetr_exp.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
from inference.core.models.exp_adapter import InferenceExpObjectDetectionModelAdapter


class RFDetrExperimentalModel(InferenceExpObjectDetectionModelAdapter):
    """Adapter for RF-DETR using the inference_exp AutoModel backend.

    This class wraps an inference_exp AutoModel to present the same interface
    as legacy models in the inference server.
    """

    def map_inference_kwargs(self, kwargs: dict) -> dict:
        return {
            "threshold": kwargs.get("confidence"),
        }

inference/models/utils.py

Lines changed: 28 additions & 0 deletions
@@ -19,6 +19,7 @@
     PALIGEMMA_ENABLED,
     QWEN_2_5_ENABLED,
     SMOLVLM2_ENABLED,
+    USE_INFERENCE_EXP_MODELS,
 )
 from inference.core.models.base import Model
 from inference.core.models.stubs import (
@@ -550,3 +551,30 @@ def get_model(model_id, api_key=API_KEY, **kwargs) -> Model:

 def get_roboflow_model(*args, **kwargs):
     return get_model(*args, **kwargs)
+
+
+# Prefer the inference_exp backend for RF-DETR and YOLOv8 variants when enabled and available
+try:
+    if USE_INFERENCE_EXP_MODELS:
+        # Ensure the experimental package is importable before swapping
+        __import__("inference_exp")
+        from inference.models.rfdetr.rfdetr_exp import RFDetrExperimentalModel
+        from inference.models.yolov8.yolov8_object_detection_exp import (
+            Yolo8ODExperimentalModel,
+        )
+
+        for task, variant in ROBOFLOW_MODEL_TYPES.keys():
+            if task == "object-detection" and variant.startswith("rfdetr-"):
+                ROBOFLOW_MODEL_TYPES[(task, variant)] = RFDetrExperimentalModel
+
+        # Replace every object-detection entry whose variant starts with
+        # "yolov8" with the experimental model
+        for task, variant in ROBOFLOW_MODEL_TYPES.keys():
+            if task == "object-detection" and variant.startswith("yolov8"):
+                ROBOFLOW_MODEL_TYPES[(task, variant)] = Yolo8ODExperimentalModel
+
+except Exception:
+    # Fall back to the legacy model stack when the experimental stack is unavailable
+    warnings.warn(
+        "Inference experimental stack is unavailable, falling back to regular model inference stack"
+    )
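When the swap runs, the change is observable in the registry itself; an illustrative check, assuming the registry keys are (task_type, model_variant) tuples as the loops above imply, with "yolov8n" as an assumed variant name:

    from inference.models.utils import ROBOFLOW_MODEL_TYPES

    # "yolov8n" is a hypothetical variant key for illustration
    cls = ROBOFLOW_MODEL_TYPES.get(("object-detection", "yolov8n"))
    print(cls)  # Yolo8ODExperimentalModel if the swap ran, else the legacy class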
inference/models/yolov8/yolov8_object_detection_exp.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from inference.core.models.exp_adapter import InferenceExpObjectDetectionModelAdapter


class Yolo8ODExperimentalModel(InferenceExpObjectDetectionModelAdapter):
    def map_inference_kwargs(self, kwargs: dict) -> dict:
        return {
            "conf_thresh": kwargs.get("confidence"),
            "iou_thresh": kwargs.get("iou_threshold"),
            "max_detections": kwargs.get("max_detections"),
            "class_agnostic": kwargs.get("class_agnostic"),
        }
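Contrasted with the RF-DETR adapter, this mapping forwards the NMS-related kwargs that the RF-DETR mapping drops; an illustrative translation of one request, with values mirroring the two map_inference_kwargs methods above:

    request_kwargs = {"confidence": 0.4, "iou_threshold": 0.5, "max_detections": 300}

    rfdetr_kwargs = {"threshold": request_kwargs.get("confidence")}
    yolov8_kwargs = {
        "conf_thresh": request_kwargs.get("confidence"),
        "iou_thresh": request_kwargs.get("iou_threshold"),
        "max_detections": request_kwargs.get("max_detections"),
        "class_agnostic": request_kwargs.get("class_agnostic"),  # None when unset
    }
    print(rfdetr_kwargs)  # {'threshold': 0.4}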
