1616 MODELS_CACHE_AUTH_CACHE_TTL ,
1717 MODELS_CACHE_AUTH_ENABLED ,
1818 PRELOAD_HF_IDS ,
19+ PRELOAD_MODELS ,
1920 PROJECT ,
2021 ROBOFLOW_INTERNAL_SERVICE_SECRET ,
2122 WEBRTC_MODAL_APP_NAME ,
4647from inference .core .interfaces .webrtc_worker .webrtc import (
4748 init_rtc_peer_connection_with_loop ,
4849)
50+ from inference .core .managers .base import ModelManager
51+ from inference .core .registries .roboflow import RoboflowModelRegistry
4952from inference .core .roboflow_api import get_roboflow_workspace
5053from inference .core .version import __version__
54+ from inference .models .aliases import resolve_roboflow_model_alias
55+ from inference .models .utils import ROBOFLOW_MODEL_TYPES
5156from inference .usage_tracking .collector import usage_collector
5257from inference .usage_tracking .plan_details import WebRTCPlan
5358
109114 "MODELS_CACHE_AUTH_ENABLED" : str (MODELS_CACHE_AUTH_ENABLED ),
110115 "LOG_LEVEL" : LOG_LEVEL ,
111116 "ONNXRUNTIME_EXECUTION_PROVIDERS" : "[CUDAExecutionProvider,CPUExecutionProvider]" ,
112- "PRELOAD_HF_IDS" : PRELOAD_HF_IDS ,
117+ "PRELOAD_HF_IDS" : "," .join (PRELOAD_HF_IDS ) if PRELOAD_HF_IDS else "" ,
118+ "PRELOAD_MODELS" : "," .join (PRELOAD_MODELS ) if PRELOAD_MODELS else "" ,
113119 "PROJECT" : PROJECT ,
114120 "ROBOFLOW_INTERNAL_SERVICE_NAME" : WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME ,
115121 "ROBOFLOW_INTERNAL_SERVICE_SECRET" : ROBOFLOW_INTERNAL_SERVICE_SECRET ,
135141 }
136142
137143 class RTCPeerConnectionModal :
138- _webrtc_request : Optional [WebRTCWorkerRequest ] = modal .parameter (default = None )
144+ _model_manager : Optional [ModelManager ] = modal .parameter (default = None )
139145
140146 @modal .method ()
141147 def rtc_peer_connection_modal (
@@ -145,6 +151,14 @@ def rtc_peer_connection_modal(
145151 ):
146152 logger .info ("*** Spawning %s:" , self .__class__ .__name__ )
147153 logger .info ("Inference tag: %s" , docker_tag )
154+ logger .info (
155+ "Preloaded models: %s" ,
156+ (
157+ ", " .join (self ._model_manager .models ().keys ())
158+ if self ._model_manager
159+ else ""
160+ ),
161+ )
148162 _exec_session_started = datetime .datetime .now ()
149163 webrtc_request .processing_session_started = _exec_session_started
150164 logger .info (
@@ -170,7 +184,6 @@ def rtc_peer_connection_modal(
170184 else []
171185 ),
172186 )
173- self ._webrtc_request = webrtc_request
174187
175188 def send_answer (obj : WebRTCWorkerResult ):
176189 logger .info ("Sending webrtc answer" )
@@ -180,35 +193,36 @@ def send_answer(obj: WebRTCWorkerResult):
180193 init_rtc_peer_connection_with_loop (
181194 webrtc_request = webrtc_request ,
182195 send_answer = send_answer ,
196+ model_manager = self ._model_manager ,
183197 )
184198 )
185199 _exec_session_stopped = datetime .datetime .now ()
186200 logger .info (
187201 "WebRTC session stopped at %s" ,
188202 _exec_session_stopped .isoformat (),
189203 )
190- workflow_id = self . _webrtc_request .workflow_configuration .workflow_id
204+ workflow_id = webrtc_request .workflow_configuration .workflow_id
191205 if not workflow_id :
192- if self . _webrtc_request .workflow_configuration .workflow_specification :
206+ if webrtc_request .workflow_configuration .workflow_specification :
193207 workflow_id = usage_collector ._calculate_resource_hash (
194- resource_details = self . _webrtc_request .workflow_configuration .workflow_specification
208+ resource_details = webrtc_request .workflow_configuration .workflow_specification
195209 )
196210 else :
197211 workflow_id = "unknown"
198212
199213 # requested plan is guaranteed to be set due to validation in spawn_rtc_peer_connection_modal
200- webrtc_plan = self . _webrtc_request .requested_plan
214+ webrtc_plan = webrtc_request .requested_plan
201215
202216 video_source = "realtime browser stream"
203- if self . _webrtc_request .rtsp_url :
217+ if webrtc_request .rtsp_url :
204218 video_source = "rtsp"
205- elif not self . _webrtc_request .webrtc_realtime_processing :
219+ elif not webrtc_request .webrtc_realtime_processing :
206220 video_source = "buffered browser stream"
207221
208222 usage_collector .record_usage (
209223 source = workflow_id ,
210224 category = "modal" ,
211- api_key = self . _webrtc_request .api_key ,
225+ api_key = webrtc_request .api_key ,
212226 resource_details = {
213227 "plan" : webrtc_plan ,
214228 "billable" : True ,
@@ -221,13 +235,6 @@ def send_answer(obj: WebRTCWorkerResult):
221235 usage_collector .push_usage_payloads ()
222236 logger .info ("Function completed" )
223237
224- # https://modal.com/docs/reference/modal.enter
225- # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
226- @modal .enter (snap = True )
227- def start (self ):
228- # TODO: pre-load models
229- logger .info ("Starting container" )
230-
231238 @modal .exit ()
232239 def stop (self ):
233240 logger .info ("Stopping container" )
@@ -238,7 +245,11 @@ def stop(self):
238245 ** decorator_kwargs ,
239246 )
240247 class RTCPeerConnectionModalCPU (RTCPeerConnectionModal ):
241- pass
248+ # https://modal.com/docs/reference/modal.enter
249+ @modal .enter (snap = True )
250+ def start (self ):
251+ # TODO: pre-load models on CPU
252+ logger .info ("Starting CPU container" )
242253
243254 @app .cls (
244255 ** {
@@ -250,7 +261,39 @@ class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
250261 }
251262 )
252263 class RTCPeerConnectionModalGPU (RTCPeerConnectionModal ):
253- pass
264+ # https://modal.com/docs/reference/modal.enter
265+ # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
266+ @modal .enter (snap = True )
267+ def start (self ):
268+ logger .info ("Starting GPU container" )
269+ logger .info ("Preload hf ids: %s" , PRELOAD_HF_IDS )
270+ logger .info ("Preload models: %s" , PRELOAD_MODELS )
271+ if PRELOAD_HF_IDS :
272+ # Kick off pre-loading of models (owlv2 preloading is based on module-level singleton)
273+ logger .info ("Preloading owlv2 base model" )
274+ import inference .models .owlv2 .owlv2
275+ if PRELOAD_MODELS :
276+ model_registry = RoboflowModelRegistry (ROBOFLOW_MODEL_TYPES )
277+ model_manager = ModelManager (model_registry = model_registry )
278+ for model_id in PRELOAD_MODELS :
279+ try :
280+ de_aliased_model_id = resolve_roboflow_model_alias (
281+ model_id = model_id
282+ )
283+ logger .info (f"Preloading model: { de_aliased_model_id } " )
284+ model_manager .add_model (
285+ model_id = de_aliased_model_id ,
286+ api_key = None ,
287+ countinference = False ,
288+ service_secret = ROBOFLOW_INTERNAL_SERVICE_SECRET ,
289+ )
290+ except Exception as exc :
291+ logger .error (
292+ "Failed to preload model %s: %s" ,
293+ model_id ,
294+ exc ,
295+ )
296+ self ._model_manager = model_manager
254297
255298 def spawn_rtc_peer_connection_modal (
256299 webrtc_request : WebRTCWorkerRequest ,
0 commit comments