Skip to content

Commit 46d9f1d

Browse files
Merge pull request #1684 from roboflow/feat/webrtc_worker-timeout
Make webrtc_worker time out gracefully if WEBRTC_MODAL_FUNCTION_TIME_LIMIT is set
2 parents a5f2419 + 82fa9be commit 46d9f1d

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

inference/core/env.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,9 @@
695695
WEBRTC_MODAL_APP_NAME = os.getenv(
696696
"WEBRTC_MODAL_APP_NAME", f"inference-webrtc-{PROJECT}"
697697
)
698+
# seconds
698699
WEBRTC_MODAL_RESPONSE_TIMEOUT = int(os.getenv("WEBRTC_MODAL_RESPONSE_TIMEOUT", "60"))
700+
# seconds
699701
WEBRTC_MODAL_FUNCTION_TIME_LIMIT = int(
700702
os.getenv("WEBRTC_MODAL_FUNCTION_TIME_LIMIT", "60")
701703
)
@@ -716,6 +718,7 @@
716718
WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS = int(
717719
os.getenv("WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS", "0")
718720
)
721+
# seconds
719722
WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW = int(
720723
os.getenv("WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW", "15")
721724
)

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from inference.core import logger
2323
from inference.core.env import (
24+
WEBRTC_MODAL_FUNCTION_TIME_LIMIT,
2425
WEBRTC_MODAL_RTSP_PLACEHOLDER,
2526
WEBRTC_MODAL_RTSP_PLACEHOLDER_URL,
2627
)
@@ -47,6 +48,7 @@
4748
)
4849
from inference.core.workflows.errors import WorkflowSyntaxError
4950
from inference.core.workflows.execution_engine.entities.base import WorkflowImageData
51+
from inference.usage_tracking.collector import usage_collector
5052

5153
logging.getLogger("aiortc").setLevel(logging.WARNING)
5254

@@ -71,12 +73,15 @@ def __init__(
7173
data_output: Optional[str] = None,
7274
stream_output: Optional[str] = None,
7375
declared_fps: float = 30,
76+
termination_date: Optional[datetime.datetime] = None,
77+
terminate_event: Optional[asyncio.Event] = None,
7478
*args,
7579
**kwargs,
7680
):
7781
super().__init__(*args, **kwargs)
7882
self._loop = asyncio_loop
79-
83+
self._termination_date = termination_date
84+
self._terminate_event = terminate_event
8085
self.track: Optional[RemoteStreamTrack] = None
8186
self._track_active: bool = False
8287

@@ -116,6 +121,15 @@ async def recv(self):
116121
av_logging.set_libav_level(av_logging.ERROR)
117122
self._av_logging_set = True
118123

124+
if (
125+
self._termination_date
126+
and self._termination_date < datetime.datetime.now()
127+
and self._terminate_event
128+
and not self._terminate_event.is_set()
129+
):
130+
logger.info("Timeout reached, terminating inference pipeline")
131+
self._terminate_event.set()
132+
119133
frame: VideoFrame = await self.track.recv()
120134

121135
self._received_frames += 1
@@ -200,6 +214,18 @@ async def init_rtc_peer_connection_with_loop(
200214
send_answer: Callable[[WebRTCWorkerResult], None],
201215
asyncio_loop: Optional[asyncio.AbstractEventLoop] = None,
202216
) -> RTCPeerConnectionWithLoop:
217+
termination_date = None
218+
terminate_event = asyncio.Event()
219+
220+
if WEBRTC_MODAL_FUNCTION_TIME_LIMIT is not None:
221+
try:
222+
time_limit_seconds = int(WEBRTC_MODAL_FUNCTION_TIME_LIMIT)
223+
termination_date = datetime.datetime.now() + datetime.timedelta(
224+
seconds=time_limit_seconds - 1
225+
)
226+
logger.info("Setting termination date to %s", termination_date)
227+
except (TypeError, ValueError):
228+
pass
203229
stream_output = None
204230
if webrtc_request.stream_output:
205231
# TODO: UI sends None as stream_output for wildcard outputs
@@ -216,6 +242,8 @@ async def init_rtc_peer_connection_with_loop(
216242
data_output=data_output,
217243
stream_output=stream_output,
218244
declared_fps=webrtc_request.declared_fps,
245+
termination_date=termination_date,
246+
terminate_event=terminate_event,
219247
)
220248
except (
221249
ValidationError,
@@ -280,7 +308,6 @@ async def init_rtc_peer_connection_with_loop(
280308
asyncio_loop=asyncio_loop,
281309
)
282310

283-
closed = asyncio.Event()
284311
relay = MediaRelay()
285312

286313
player: Optional[MediaPlayer] = None
@@ -324,7 +351,7 @@ async def on_connectionstatechange():
324351
video_transform_track.track.stop()
325352
logger.info("Stopping WebRTC peer")
326353
await peer_connection.close()
327-
closed.set()
354+
terminate_event.set()
328355
logger.info("'connectionstatechange' event handler finished")
329356

330357
@peer_connection.on("datachannel")
@@ -367,7 +394,7 @@ def on_message(message):
367394
)
368395
)
369396

370-
await closed.wait()
397+
await terminate_event.wait()
371398
if player:
372399
logger.info("Stopping player")
373400
player.video.stop()
@@ -377,4 +404,5 @@ def on_message(message):
377404
if video_transform_track.track:
378405
logger.info("Stopping video transform track")
379406
video_transform_track.track.stop()
407+
await usage_collector.async_push_usage_payloads()
380408
logger.info("WebRTC peer connection closed")

0 commit comments

Comments
 (0)