
Commit c06a52a

Merge branch 'main' into webrtc-transport-datachannel
2 parents: d7e3295 + d67dac6

4 files changed: +83 -27 lines changed

inference/core/env.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@
 # and also ENABLE_STREAM_API environmental variable is set to False
 PRELOAD_HF_IDS = os.getenv("PRELOAD_HF_IDS")
 if PRELOAD_HF_IDS:
-    PRELOAD_HF_IDS = [id.strip() for id in PRELOAD_HF_IDS.split(",")]
+    PRELOAD_HF_IDS = [m.strip() for m in PRELOAD_HF_IDS.split(",")]

 # Maximum batch size for GAZE, default is 8
 GAZE_MAX_BATCH_SIZE = int(os.getenv("GAZE_MAX_BATCH_SIZE", 8))
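The only functional content here is a rename: the comprehension variable `id` shadowed Python's builtin `id()`, so it becomes `m`; parsing behavior is unchanged. A minimal standalone sketch of the same parsing (the model ids are placeholders, not values from this repo):

import os

# Simulate the deployment setting a comma-separated list of HF model ids.
os.environ["PRELOAD_HF_IDS"] = "google/owlv2-base-patch16, some-org/some-model"

preload_hf_ids = os.getenv("PRELOAD_HF_IDS")
if preload_hf_ids:
    # Split on commas and trim whitespace around each id.
    preload_hf_ids = [m.strip() for m in preload_hf_ids.split(",")]

print(preload_hf_ids)  # ['google/owlv2-base-patch16', 'some-org/some-model']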

inference/core/interfaces/stream/inference_pipeline.py

Lines changed: 12 additions & 7 deletions
@@ -55,6 +55,7 @@
     PipelineWatchDog,
 )
 from inference.core.managers.active_learning import BackgroundTaskActiveLearningManager
+from inference.core.managers.base import ModelManager
 from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
 from inference.core.registries.roboflow import RoboflowModelRegistry
 from inference.core.utils.function import experimental
@@ -486,6 +487,7 @@ def init_with_workflow(
         serialize_results: bool = False,
         predictions_queue_size: int = PREDICTIONS_QUEUE_SIZE,
         decoding_buffer_size: int = DEFAULT_BUFFER_SIZE,
+        model_manager: Optional[ModelManager] = None,
     ) -> "InferencePipeline":
         """
         This class creates the abstraction for making inferences from given workflow against video stream.
@@ -566,6 +568,8 @@ def init_with_workflow(
                 default value is taken from INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE env variable
             decoding_buffer_size (int): size of video source decoding buffer
                 default value is taken from VIDEO_SOURCE_BUFFER_SIZE env variable
+            model_manager (Optional[ModelManager]): Model manager to be used by InferencePipeline, defaults to
+                BackgroundTaskActiveLearningManager with WithFixedSizeCache

         Other ENV variables involved in low-level configuration:
         * INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching
@@ -623,13 +627,14 @@ def init_with_workflow(
             use_cache=use_workflow_definition_cache,
         )
         model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
-        model_manager = BackgroundTaskActiveLearningManager(
-            model_registry=model_registry, cache=cache
-        )
-        model_manager = WithFixedSizeCache(
-            model_manager,
-            max_size=MAX_ACTIVE_MODELS,
-        )
+        if model_manager is None:
+            model_manager = BackgroundTaskActiveLearningManager(
+                model_registry=model_registry, cache=cache
+            )
+            model_manager = WithFixedSizeCache(
+                model_manager,
+                max_size=MAX_ACTIVE_MODELS,
+            )
         if workflow_init_parameters is None:
             workflow_init_parameters = {}
         thread_pool_executor = ThreadPoolExecutor(
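`init_with_workflow` now accepts an injected manager and only builds the default BackgroundTaskActiveLearningManager + WithFixedSizeCache stack when none is given. A hedged usage sketch (the workspace, workflow id, and video source are placeholders, and constructing a bare `ModelManager` this way may differ in detail from the real API):

from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
from inference.core.managers.base import ModelManager
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.models.utils import ROBOFLOW_MODEL_TYPES

# A manager prepared elsewhere, e.g. with models already loaded at startup.
manager = ModelManager(model_registry=RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES))

pipeline = InferencePipeline.init_with_workflow(
    video_reference="rtsp://example.com/stream",  # placeholder source
    workspace_name="my-workspace",                # placeholder
    workflow_id="my-workflow",                    # placeholder
    model_manager=manager,                        # new injection point from this commit
    on_prediction=lambda predictions, frame: None,
)
pipeline.start()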

inference/core/interfaces/webrtc_worker/modal.py

Lines changed: 62 additions & 19 deletions
@@ -16,6 +16,7 @@
     MODELS_CACHE_AUTH_CACHE_TTL,
     MODELS_CACHE_AUTH_ENABLED,
     PRELOAD_HF_IDS,
+    PRELOAD_MODELS,
     PROJECT,
     ROBOFLOW_INTERNAL_SERVICE_SECRET,
     WEBRTC_MODAL_APP_NAME,
@@ -46,8 +47,12 @@
 from inference.core.interfaces.webrtc_worker.webrtc import (
     init_rtc_peer_connection_with_loop,
 )
+from inference.core.managers.base import ModelManager
+from inference.core.registries.roboflow import RoboflowModelRegistry
 from inference.core.roboflow_api import get_roboflow_workspace
 from inference.core.version import __version__
+from inference.models.aliases import resolve_roboflow_model_alias
+from inference.models.utils import ROBOFLOW_MODEL_TYPES
 from inference.usage_tracking.collector import usage_collector
 from inference.usage_tracking.plan_details import WebRTCPlan
@@ -109,7 +114,8 @@
     "MODELS_CACHE_AUTH_ENABLED": str(MODELS_CACHE_AUTH_ENABLED),
     "LOG_LEVEL": LOG_LEVEL,
     "ONNXRUNTIME_EXECUTION_PROVIDERS": "[CUDAExecutionProvider,CPUExecutionProvider]",
-    "PRELOAD_HF_IDS": PRELOAD_HF_IDS,
+    "PRELOAD_HF_IDS": ",".join(PRELOAD_HF_IDS) if PRELOAD_HF_IDS else "",
+    "PRELOAD_MODELS": ",".join(PRELOAD_MODELS) if PRELOAD_MODELS else "",
     "PROJECT": PROJECT,
     "ROBOFLOW_INTERNAL_SERVICE_NAME": WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
     "ROBOFLOW_INTERNAL_SERVICE_SECRET": ROBOFLOW_INTERNAL_SERVICE_SECRET,
@@ -135,7 +141,7 @@
 }

 class RTCPeerConnectionModal:
-    _webrtc_request: Optional[WebRTCWorkerRequest] = modal.parameter(default=None)
+    _model_manager: Optional[ModelManager] = modal.parameter(default=None)

     @modal.method()
     def rtc_peer_connection_modal(
@@ -145,6 +151,14 @@ def rtc_peer_connection_modal(
     ):
         logger.info("*** Spawning %s:", self.__class__.__name__)
         logger.info("Inference tag: %s", docker_tag)
+        logger.info(
+            "Preloaded models: %s",
+            (
+                ", ".join(self._model_manager.models().keys())
+                if self._model_manager
+                else ""
+            ),
+        )
         _exec_session_started = datetime.datetime.now()
         webrtc_request.processing_session_started = _exec_session_started
         logger.info(
@@ -170,7 +184,6 @@ def rtc_peer_connection_modal(
                 else []
             ),
         )
-        self._webrtc_request = webrtc_request

         def send_answer(obj: WebRTCWorkerResult):
             logger.info("Sending webrtc answer")
@@ -180,35 +193,36 @@ def send_answer(obj: WebRTCWorkerResult):
             init_rtc_peer_connection_with_loop(
                 webrtc_request=webrtc_request,
                 send_answer=send_answer,
+                model_manager=self._model_manager,
             )
         )
         _exec_session_stopped = datetime.datetime.now()
         logger.info(
             "WebRTC session stopped at %s",
             _exec_session_stopped.isoformat(),
         )
-        workflow_id = self._webrtc_request.workflow_configuration.workflow_id
+        workflow_id = webrtc_request.workflow_configuration.workflow_id
         if not workflow_id:
-            if self._webrtc_request.workflow_configuration.workflow_specification:
+            if webrtc_request.workflow_configuration.workflow_specification:
                 workflow_id = usage_collector._calculate_resource_hash(
-                    resource_details=self._webrtc_request.workflow_configuration.workflow_specification
+                    resource_details=webrtc_request.workflow_configuration.workflow_specification
                 )
             else:
                 workflow_id = "unknown"

         # requested plan is guaranteed to be set due to validation in spawn_rtc_peer_connection_modal
-        webrtc_plan = self._webrtc_request.requested_plan
+        webrtc_plan = webrtc_request.requested_plan

         video_source = "realtime browser stream"
-        if self._webrtc_request.rtsp_url:
+        if webrtc_request.rtsp_url:
             video_source = "rtsp"
-        elif not self._webrtc_request.webrtc_realtime_processing:
+        elif not webrtc_request.webrtc_realtime_processing:
             video_source = "buffered browser stream"

         usage_collector.record_usage(
             source=workflow_id,
             category="modal",
-            api_key=self._webrtc_request.api_key,
+            api_key=webrtc_request.api_key,
             resource_details={
                 "plan": webrtc_plan,
                 "billable": True,
@@ -221,13 +235,6 @@ def send_answer(obj: WebRTCWorkerResult):
         usage_collector.push_usage_payloads()
         logger.info("Function completed")

-    # https://modal.com/docs/reference/modal.enter
-    # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
-    @modal.enter(snap=True)
-    def start(self):
-        # TODO: pre-load models
-        logger.info("Starting container")
-
     @modal.exit()
     def stop(self):
         logger.info("Stopping container")
@@ -238,7 +245,11 @@ def stop(self):
     **decorator_kwargs,
 )
 class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
-    pass
+    # https://modal.com/docs/reference/modal.enter
+    @modal.enter(snap=True)
+    def start(self):
+        # TODO: pre-load models on CPU
+        logger.info("Starting CPU container")

 @app.cls(
     **{
@@ -250,7 +261,39 @@ class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
     }
 )
 class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
-    pass
+    # https://modal.com/docs/reference/modal.enter
+    # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
+    @modal.enter(snap=True)
+    def start(self):
+        logger.info("Starting GPU container")
+        logger.info("Preload hf ids: %s", PRELOAD_HF_IDS)
+        logger.info("Preload models: %s", PRELOAD_MODELS)
+        if PRELOAD_HF_IDS:
+            # Kick off pre-loading of models (owlv2 preloading is based on module-level singleton)
+            logger.info("Preloading owlv2 base model")
+            import inference.models.owlv2.owlv2
+        if PRELOAD_MODELS:
+            model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
+            model_manager = ModelManager(model_registry=model_registry)
+            for model_id in PRELOAD_MODELS:
+                try:
+                    de_aliased_model_id = resolve_roboflow_model_alias(
+                        model_id=model_id
+                    )
+                    logger.info(f"Preloading model: {de_aliased_model_id}")
+                    model_manager.add_model(
+                        model_id=de_aliased_model_id,
+                        api_key=None,
+                        countinference=False,
+                        service_secret=ROBOFLOW_INTERNAL_SERVICE_SECRET,
+                    )
+                except Exception as exc:
+                    logger.error(
+                        "Failed to preload model %s: %s",
+                        model_id,
+                        exc,
+                    )
+            self._model_manager = model_manager

 def spawn_rtc_peer_connection_modal(
     webrtc_request: WebRTCWorkerRequest,
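Preloading moves out of the shared base class and into per-flavor `@modal.enter(snap=True)` hooks, so whatever the GPU hook loads is captured in Modal's memory snapshot and restored on later cold starts instead of being re-downloaded. A minimal sketch of that snapshot pattern, under the assumption that the class is opted into memory snapshots (the app name and payload are illustrative, not from this repo):

import modal

app = modal.App("snapshot-preload-demo")  # hypothetical app name


@app.cls(enable_memory_snapshot=True)  # opt the class into memory snapshots
class Worker:
    @modal.enter(snap=True)
    def start(self):
        # Runs before the snapshot is taken; state built here (think model
        # weights) is restored from the snapshot on later cold starts.
        self.payload = {"weights": list(range(1_000_000))}  # stand-in for real weights

    @modal.method()
    def infer(self, x: int) -> int:
        return x + len(self.payload["weights"])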

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 8 additions & 0 deletions
@@ -52,6 +52,7 @@
     detect_image_output,
     process_frame,
 )
+from inference.core.managers.base import ModelManager
 from inference.core.roboflow_api import get_workflow_specification
 from inference.core.workflows.core_steps.common.serializers import (
     serialize_wildcard_kind,
@@ -193,6 +194,7 @@ def __init__(
         asyncio_loop: asyncio.AbstractEventLoop,
         workflow_configuration: WorkflowConfiguration,
         api_key: str,
+        model_manager: Optional[ModelManager] = None,
         data_output: Optional[List[str]] = None,
         stream_output: Optional[str] = None,
         has_video_track: bool = True,
@@ -250,6 +252,7 @@ def __init__(
             workflows_thread_pool_workers=workflow_configuration.workflows_thread_pool_workers,
             cancel_thread_pool_tasks_on_exit=workflow_configuration.cancel_thread_pool_tasks_on_exit,
             video_metadata_input_name=workflow_configuration.video_metadata_input_name,
+            model_manager=model_manager,
         )

     def set_track(self, track: RemoteStreamTrack):
@@ -537,6 +540,7 @@ def __init__(
         asyncio_loop: asyncio.AbstractEventLoop,
         workflow_configuration: WorkflowConfiguration,
         api_key: str,
+        model_manager: Optional[ModelManager] = None,
         data_output: Optional[List[str]] = None,
         stream_output: Optional[str] = None,
         has_video_track: bool = True,
@@ -560,6 +564,7 @@ def __init__(
             termination_date=termination_date,
             terminate_event=terminate_event,
             use_data_channel_frames=use_data_channel_frames,
+            model_manager=model_manager,
         )

     async def _auto_detect_stream_output(
@@ -643,6 +648,7 @@ async def init_rtc_peer_connection_with_loop(
     webrtc_request: WebRTCWorkerRequest,
     send_answer: Callable[[WebRTCWorkerResult], None],
    asyncio_loop: Optional[asyncio.AbstractEventLoop] = None,
+    model_manager: Optional[ModelManager] = None,
     shutdown_reserve: int = WEBRTC_MODAL_SHUTDOWN_RESERVE,
 ) -> RTCPeerConnectionWithLoop:
     termination_date = None
@@ -694,6 +700,7 @@ async def init_rtc_peer_connection_with_loop(
         video_processor = VideoTransformTrackWithLoop(
             asyncio_loop=asyncio_loop,
             workflow_configuration=webrtc_request.workflow_configuration,
+            model_manager=model_manager,
             api_key=webrtc_request.api_key,
             data_output=data_fields,
             stream_output=stream_field,
@@ -708,6 +715,7 @@ async def init_rtc_peer_connection_with_loop(
         video_processor = VideoFrameProcessor(
             asyncio_loop=asyncio_loop,
             workflow_configuration=webrtc_request.workflow_configuration,
+            model_manager=model_manager,
             api_key=webrtc_request.api_key,
             data_output=data_fields,
             stream_output=None,
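Together the four files thread a single optional `ModelManager` from the Modal container entrypoint through `init_rtc_peer_connection_with_loop` and the video processors down to `InferencePipeline.init_with_workflow`, with the bottom layer falling back to the old construction when nothing is injected. A self-contained sketch of that injection-with-fallback pattern (generic names, not the real classes):

from typing import Optional


class ModelManager:
    """Stand-in for the real manager; holds preloaded models."""

    def __init__(self) -> None:
        self.models: dict[str, object] = {}


class Processor:
    """Stand-in for VideoFrameProcessor: accepts an optional manager."""

    def __init__(self, model_manager: Optional[ModelManager] = None) -> None:
        # Fall back to a fresh (cold) manager when none was injected,
        # mirroring init_with_workflow's `if model_manager is None` branch.
        self.model_manager = model_manager or ModelManager()


warm = ModelManager()
warm.models["yolov8n-640"] = object()  # pretend this was preloaded at snapshot time

assert Processor(warm).model_manager is warm   # injected manager is reused
assert Processor().model_manager.models == {}  # default manager is built cold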
