
Commit 246cada

Merge pull request #1666 from roboflow/feature/trt-inplementation-for-more-models
TRT implementation for more models
2 parents 1f6cbc4 + ea02738 commit 246cada

9 files changed: +1419 / -11 lines changed

inference_experimental/inference_exp/models/auto_loaders/models_registry.py

Lines changed: 24 additions & 0 deletions
@@ -297,6 +297,14 @@ class RegistryEntry:
         module_name="inference_exp.models.vit.vit_classification_huggingface",
         class_name="VITForMultiLabelClassificationHF",
     ),
+    ("vit", CLASSIFICATION_TASK, BackendType.TRT): LazyClass(
+        module_name="inference_exp.models.vit.vit_classification_trt",
+        class_name="VITForClassificationTRT",
+    ),
+    ("vit", MULTI_LABEL_CLASSIFICATION_TASK, BackendType.TRT): LazyClass(
+        module_name="inference_exp.models.vit.vit_classification_trt",
+        class_name="VITForMultiLabelClassificationTRT",
+    ),
     ("resnet", CLASSIFICATION_TASK, BackendType.ONNX): LazyClass(
         module_name="inference_exp.models.resnet.resnet_classification_onnx",
         class_name="ResNetForClassificationOnnx",
@@ -313,6 +321,14 @@ class RegistryEntry:
         module_name="inference_exp.models.resnet.resnet_classification_torch",
         class_name="ResNetForMultiLabelClassificationTorch",
     ),
+    ("resnet", CLASSIFICATION_TASK, BackendType.TRT): LazyClass(
+        module_name="inference_exp.models.resnet.resnet_classification_trt",
+        class_name="ResNetForClassificationTRT",
+    ),
+    ("resnet", MULTI_LABEL_CLASSIFICATION_TASK, BackendType.TRT): LazyClass(
+        module_name="inference_exp.models.resnet.resnet_classification_trt",
+        class_name="ResNetForMultiLabelClassificationTRT",
+    ),
     ("segment-anything-2-rt", INSTANCE_SEGMENTATION_TASK, BackendType.TORCH): LazyClass(
         module_name="inference_exp.models.sam2_rt.sam2_pytorch",
         class_name="SAM2ForStream",
@@ -325,10 +341,18 @@ class RegistryEntry:
         module_name="inference_exp.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation_onnx",
         class_name="DeepLabV3PlusForSemanticSegmentationOnnx",
     ),
+    ("deep-lab-v3-plus", SEMANTIC_SEGMENTATION_TASK, BackendType.TRT): LazyClass(
+        module_name="inference_exp.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation_trt",
+        class_name="DeepLabV3PlusForSemanticSegmentationTRT",
+    ),
     ("yolact", INSTANCE_SEGMENTATION_TASK, BackendType.ONNX): LazyClass(
         module_name="inference_exp.models.yolact.yolact_instance_segmentation_onnx",
         class_name="YOLOACTForInstanceSegmentationOnnx",
     ),
+    ("yolact", INSTANCE_SEGMENTATION_TASK, BackendType.TRT): LazyClass(
+        module_name="inference_exp.models.yolact.yolact_instance_segmentation_trt",
+        class_name="YOLOACTForInstanceSegmentationTRT",
+    ),
 }
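
Each new key follows the registry's existing lazy-loading pattern: an (architecture, task, backend) tuple maps to a LazyClass that defers importing the TRT implementation until the model is actually requested, so the TensorRT extras stay optional. Below is a minimal sketch of how such a lookup could resolve; the LazyClass here is a simplified stand-in, and the plain-string keys stand in for the task constants and BackendType enum used in the real registry, neither of which is part of this diff.

    # Hypothetical sketch of lazy registry resolution (not the library's actual LazyClass).
    import importlib
    from dataclasses import dataclass


    @dataclass(frozen=True)
    class LazyClass:
        module_name: str
        class_name: str

        def resolve(self) -> type:
            # The import happens only when the class is requested, so optional
            # backends (e.g. TensorRT) need not be installed just to build the registry.
            module = importlib.import_module(self.module_name)
            return getattr(module, self.class_name)


    REGISTRY = {
        ("deep-lab-v3-plus", "semantic-segmentation", "trt"): LazyClass(
            module_name="inference_exp.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation_trt",
            class_name="DeepLabV3PlusForSemanticSegmentationTRT",
        ),
    }


    def resolve_model_class(model: str, task: str, backend: str) -> type:
        # Raises KeyError for unknown combinations; resolving the entry above
        # requires inference_exp with the TRT extras installed.
        return REGISTRY[(model, task, backend)].resolve()
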
inference_experimental/inference_exp/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py

Lines changed: 307 additions & 0 deletions
@@ -0,0 +1,307 @@
from threading import Lock
from typing import List, Optional, Tuple, Union

import torch
from inference_exp import ColorFormat, SemanticSegmentationModel
from inference_exp.configuration import DEFAULT_DEVICE
from inference_exp.errors import (
    CorruptedModelPackageError,
    MissingDependencyError,
    ModelRuntimeError,
)
from inference_exp.models.base.semantic_segmentation import SemanticSegmentationResult
from inference_exp.models.base.types import PreprocessedInputs, PreprocessingMetadata
from inference_exp.models.common.cuda import use_cuda_context, use_primary_cuda_context
from inference_exp.models.common.model_packages import get_model_package_contents
from inference_exp.models.common.roboflow.model_packages import (
    InferenceConfig,
    PreProcessingMetadata,
    ResizeMode,
    TRTConfig,
    parse_class_names_file,
    parse_inference_config,
    parse_trt_config,
)
from inference_exp.models.common.roboflow.pre_processing import (
    pre_process_network_input,
)
from inference_exp.models.common.trt import (
    get_engine_inputs_and_outputs,
    infer_from_trt_engine,
    load_model,
)
from torchvision.transforms import functional

try:
    import tensorrt as trt
except ImportError as import_error:
    raise MissingDependencyError(
        message=f"Could not import DeepLabV3Plus model with TRT backend - this error means that some additional "
        f"dependencies are not installed in the environment. If you run the `inference-exp` library directly in "
        f"your Python program, make sure the following extras of the package are installed: `trt10` - installation "
        f"can only succeed for Linux and Windows machines with CUDA 12 installed. Jetson devices should have "
        f"TRT 10.x installed for all builds with Jetpack 6. "
        f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
        f"You can also contact Roboflow to get support.",
        help_url="https://todo",
    ) from import_error

try:
    import pycuda.driver as cuda
except ImportError as import_error:
    raise MissingDependencyError(
        message="TODO", help_url="https://todo"
    ) from import_error


class DeepLabV3PlusForSemanticSegmentationTRT(
    SemanticSegmentationModel[torch.Tensor, PreProcessingMetadata, torch.Tensor]
):

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        engine_host_code_allowed: bool = False,
        **kwargs,
    ) -> "DeepLabV3PlusForSemanticSegmentationTRT":
        if device.type != "cuda":
            raise ModelRuntimeError(
                message=f"TRT engine only runs on CUDA device - {device} device detected.",
                help_url="https://todo",
            )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "trt_config.json",
                "engine.plan",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        try:
            background_class_id = [c.lower() for c in class_names].index("background")
        except ValueError:
            background_class_id = -1
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        trt_config = parse_trt_config(
            config_path=model_package_content["trt_config.json"]
        )
        cuda.init()
        cuda_device = cuda.Device(device.index or 0)
        with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
            engine = load_model(
                model_path=model_package_content["engine.plan"],
                engine_host_code_allowed=engine_host_code_allowed,
            )
            execution_context = engine.create_execution_context()
        inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
        if len(inputs) != 1:
            raise CorruptedModelPackageError(
                message=f"Implementation assumes a single model input, found: {len(inputs)}.",
                help_url="https://todo",
            )
        if len(outputs) != 1:
            raise CorruptedModelPackageError(
                message=f"Implementation assumes a single model output, found: {len(outputs)}.",
                help_url="https://todo",
            )
        return cls(
            engine=engine,
            input_name=inputs[0],
            output_name=outputs[0],
            class_names=class_names,
            background_class_id=background_class_id,
            inference_config=inference_config,
            trt_config=trt_config,
            device=device,
            cuda_context=cuda_context,
            execution_context=execution_context,
        )

    def __init__(
        self,
        engine: trt.ICudaEngine,
        input_name: str,
        output_name: str,
        class_names: List[str],
        background_class_id: int,
        inference_config: InferenceConfig,
        trt_config: TRTConfig,
        device: torch.device,
        cuda_context: cuda.Context,
        execution_context: trt.IExecutionContext,
    ):
        self._engine = engine
        self._input_name = input_name
        self._output_names = [output_name]
        self._class_names = class_names
        self._background_class_id = background_class_id
        self._inference_config = inference_config
        self._trt_config = trt_config
        self._device = device
        self._cuda_context = cuda_context
        self._execution_context = execution_context
        self._lock = Lock()

    @property
    def class_names(self) -> List[str]:
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor]],
        input_color_format: Optional[ColorFormat] = None,
        **kwargs,
    ) -> Tuple[PreprocessedInputs, PreprocessingMetadata]:
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
        )

    def forward(
        self, pre_processed_images: PreprocessedInputs, **kwargs
    ) -> torch.Tensor:
        with self._lock:
            with use_cuda_context(context=self._cuda_context):
                return infer_from_trt_engine(
                    pre_processed_images=pre_processed_images,
                    trt_config=self._trt_config,
                    engine=self._engine,
                    context=self._execution_context,
                    device=self._device,
                    input_name=self._input_name,
                    outputs=self._output_names,
                )[0]

    def post_process(
        self,
        model_results: torch.Tensor,
        pre_processing_meta: PreprocessedInputs,
        confidence_threshold: float = 0.5,
        **kwargs,
    ) -> List[SemanticSegmentationResult]:
        results = []
        for image_results, image_metadata in zip(model_results, pre_processing_meta):
            inference_size = image_metadata.inference_size
            mask_h_scale = model_results.shape[2] / inference_size.height
            mask_w_scale = model_results.shape[3] / inference_size.width
            mask_pad_top, mask_pad_bottom, mask_pad_left, mask_pad_right = (
                round(mask_h_scale * image_metadata.pad_top),
                round(mask_h_scale * image_metadata.pad_bottom),
                round(mask_w_scale * image_metadata.pad_left),
                round(mask_w_scale * image_metadata.pad_right),
            )
            _, mh, mw = image_results.shape
            if (
                mask_pad_top < 0
                or mask_pad_bottom < 0
                or mask_pad_left < 0
                or mask_pad_right < 0
            ):
                image_results = torch.nn.functional.pad(
                    image_results,
                    (
                        abs(min(mask_pad_left, 0)),
                        abs(min(mask_pad_right, 0)),
                        abs(min(mask_pad_top, 0)),
                        abs(min(mask_pad_bottom, 0)),
                    ),
                    "constant",
                    self._background_class_id,
                )
                padded_mask_offset_top = max(mask_pad_top, 0)
                padded_mask_offset_bottom = max(mask_pad_bottom, 0)
                padded_mask_offset_left = max(mask_pad_left, 0)
                padded_mask_offset_right = max(mask_pad_right, 0)
                image_results = image_results[
                    :,
                    padded_mask_offset_top : image_results.shape[1]
                    - padded_mask_offset_bottom,
                    padded_mask_offset_left : image_results.shape[2]
                    - padded_mask_offset_right,
                ]
            else:
                image_results = image_results[
                    :,
                    mask_pad_top : mh - mask_pad_bottom,
                    mask_pad_left : mw - mask_pad_right,
                ]
            if (
                image_results.shape[1]
                != image_metadata.size_after_pre_processing.height
                or image_results.shape[2]
                != image_metadata.size_after_pre_processing.width
            ):
                image_results = functional.resize(
                    image_results,
                    [
                        image_metadata.size_after_pre_processing.height,
                        image_metadata.size_after_pre_processing.width,
                    ],
                    interpolation=functional.InterpolationMode.BILINEAR,
                )
            image_results = torch.nn.functional.softmax(image_results, dim=0)
            image_confidence, image_class_ids = torch.max(image_results, dim=0)
            below_threshold = image_confidence < confidence_threshold
            image_confidence[below_threshold] = 0.0
            image_class_ids[below_threshold] = self._background_class_id
            if (
                image_metadata.static_crop_offset.offset_x > 0
                or image_metadata.static_crop_offset.offset_y > 0
            ):
                original_size_confidence_canvas = torch.zeros(
                    (
                        image_metadata.original_size.height,
                        image_metadata.original_size.width,
                    ),
                    device=self._device,
                    dtype=image_confidence.dtype,
                )
                original_size_confidence_canvas[
                    image_metadata.static_crop_offset.offset_y : image_metadata.static_crop_offset.offset_y
                    + image_confidence.shape[0],
                    image_metadata.static_crop_offset.offset_x : image_metadata.static_crop_offset.offset_x
                    + image_confidence.shape[1],
                ] = image_confidence
                original_size_confidence_class_id_canvas = (
                    torch.ones(
                        (
                            image_metadata.original_size.height,
                            image_metadata.original_size.width,
                        ),
                        device=self._device,
                        dtype=image_class_ids.dtype,
                    )
                    * self._background_class_id
                )
                original_size_confidence_class_id_canvas[
                    image_metadata.static_crop_offset.offset_y : image_metadata.static_crop_offset.offset_y
                    + image_class_ids.shape[0],
                    image_metadata.static_crop_offset.offset_x : image_metadata.static_crop_offset.offset_x
                    + image_class_ids.shape[1],
                ] = image_class_ids
                image_class_ids = original_size_confidence_class_id_canvas
                image_confidence = original_size_confidence_canvas
            results.append(
                SemanticSegmentationResult(
                    segmentation_map=image_class_ids,
                    confidence=image_confidence,
                )
            )
        return results
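
For illustration, a minimal usage sketch of the new class, assuming a CUDA machine with the `trt10` extras installed; the `./model_package` path and the dummy input tensor are placeholders, and the expected image layout ultimately depends on `pre_process_network_input`.

    import torch
    from inference_exp.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation_trt import (
        DeepLabV3PlusForSemanticSegmentationTRT,
    )

    # "./model_package" is a placeholder: the directory must contain class_names.txt,
    # inference_config.json, trt_config.json and engine.plan, as required by from_pretrained.
    model = DeepLabV3PlusForSemanticSegmentationTRT.from_pretrained(
        "./model_package", device=torch.device("cuda:0")
    )

    # Dummy CHW uint8 image standing in for a real frame.
    image = torch.randint(0, 255, (3, 640, 640), dtype=torch.uint8)

    pre_processed, metadata = model.pre_process(image)
    raw_output = model.forward(pre_processed)
    results = model.post_process(raw_output, metadata, confidence_threshold=0.5)

    for result in results:
        print(result.segmentation_map.shape, result.confidence.shape)
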
