Skip to content

Commit c6ab823

Browse files
Merge pull request #1704 from roboflow/webrtc-transport-datachannel
Workflows Streaming - Add transport via datachannel for guaranteed frame processing
2 parents d67dac6 + c06a52a commit c6ab823

File tree

2 files changed

+214
-15
lines changed

2 files changed

+214
-15
lines changed

inference/core/interfaces/webrtc_worker/entities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class WebRTCWorkerRequest(BaseModel):
3939
data_output: Optional[List[str]] = Field(default=None)
4040
declared_fps: Optional[float] = None
4141
rtsp_url: Optional[str] = None
42+
use_data_channel_frames: bool = (
43+
False # When True, expect frames via data channel instead of media track
44+
)
4245
processing_timeout: Optional[int] = WEBRTC_MODAL_FUNCTION_TIME_LIMIT
4346
processing_session_started: Optional[datetime.datetime] = None
4447
requested_plan: Optional[str] = "webrtc-gpu-small"

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 211 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import datetime
33
import json
44
import logging
5+
import struct
56
from typing import Any, Callable, Dict, List, Optional, Tuple
67

8+
import cv2
9+
import numpy as np
710
from aiortc import (
811
RTCConfiguration,
912
RTCDataChannel,
@@ -60,6 +63,113 @@
6063

6164
logging.getLogger("aiortc").setLevel(logging.WARNING)
6265

66+
# WebRTC data channel chunking configuration
67+
CHUNK_SIZE = 48 * 1024 # 48KB - safe for all WebRTC implementations
68+
69+
70+
def create_chunked_binary_message(
71+
frame_id: int, chunk_index: int, total_chunks: int, payload: bytes
72+
) -> bytes:
73+
"""Create a binary message with standard 12-byte header.
74+
75+
Format: [frame_id: 4][chunk_index: 4][total_chunks: 4][payload: N]
76+
All integers are uint32 little-endian.
77+
"""
78+
header = struct.pack("<III", frame_id, chunk_index, total_chunks)
79+
return header + payload
80+
81+
82+
def parse_chunked_binary_message(message: bytes) -> Tuple[int, int, int, bytes]:
83+
"""Parse a binary message with standard 12-byte header.
84+
85+
Returns: (frame_id, chunk_index, total_chunks, payload)
86+
"""
87+
if len(message) < 12:
88+
raise ValueError(f"Message too short: {len(message)} bytes (expected >= 12)")
89+
90+
frame_id, chunk_index, total_chunks = struct.unpack("<III", message[0:12])
91+
payload = message[12:]
92+
return frame_id, chunk_index, total_chunks, payload
93+
94+
95+
class ChunkReassembler:
96+
"""Helper to reassemble chunked binary messages."""
97+
98+
def __init__(self):
99+
self._chunks: Dict[int, Dict[int, bytes]] = (
100+
{}
101+
) # {frame_id: {chunk_index: data}}
102+
self._total: Dict[int, int] = {} # {frame_id: total_chunks}
103+
104+
def add_chunk(
105+
self, frame_id: int, chunk_index: int, total_chunks: int, chunk_data: bytes
106+
) -> Optional[bytes]:
107+
"""Add a chunk and return complete payload if all chunks received.
108+
109+
Returns:
110+
Complete reassembled payload bytes if all chunks received, None otherwise.
111+
"""
112+
# Initialize buffers for new frame
113+
if frame_id not in self._chunks:
114+
self._chunks[frame_id] = {}
115+
self._total[frame_id] = total_chunks
116+
117+
# Store chunk
118+
self._chunks[frame_id][chunk_index] = chunk_data
119+
120+
# Check if all chunks received
121+
if len(self._chunks[frame_id]) >= total_chunks:
122+
# Reassemble in order
123+
complete_payload = b"".join(
124+
self._chunks[frame_id][i] for i in range(total_chunks)
125+
)
126+
127+
# Clean up
128+
del self._chunks[frame_id]
129+
del self._total[frame_id]
130+
131+
return complete_payload
132+
133+
return None
134+
135+
136+
def send_chunked_data(
137+
data_channel: RTCDataChannel,
138+
frame_id: int,
139+
payload_bytes: bytes,
140+
chunk_size: int = CHUNK_SIZE,
141+
) -> None:
142+
"""Send payload via data channel, automatically chunking if needed.
143+
144+
Args:
145+
data_channel: RTCDataChannel to send on
146+
frame_id: Frame identifier
147+
payload_bytes: Data to send (JPEG, JSON UTF-8, etc.)
148+
chunk_size: Maximum chunk size (default 48KB)
149+
"""
150+
if data_channel.readyState != "open":
151+
logger.warning(f"Cannot send response for frame {frame_id}, channel not open")
152+
return
153+
154+
total_chunks = (
155+
len(payload_bytes) + chunk_size - 1
156+
) // chunk_size # Ceiling division
157+
158+
if frame_id % 100 == 1:
159+
logger.info(
160+
f"Sending response for frame {frame_id}: {total_chunks} chunk(s), {len(payload_bytes)} bytes"
161+
)
162+
163+
for chunk_index in range(total_chunks):
164+
start = chunk_index * chunk_size
165+
end = min(start + chunk_size, len(payload_bytes))
166+
chunk_data = payload_bytes[start:end]
167+
168+
message = create_chunked_binary_message(
169+
frame_id, chunk_index, total_chunks, chunk_data
170+
)
171+
data_channel.send(message)
172+
63173

64174
class RTCPeerConnectionWithLoop(RTCPeerConnection):
65175
def __init__(
@@ -91,6 +201,7 @@ def __init__(
91201
declared_fps: float = 30,
92202
termination_date: Optional[datetime.datetime] = None,
93203
terminate_event: Optional[asyncio.Event] = None,
204+
use_data_channel_frames: bool = False,
94205
):
95206
self._loop = asyncio_loop
96207
self._termination_date = termination_date
@@ -101,6 +212,11 @@ def __init__(
101212
self._received_frames = 0
102213
self._declared_fps = declared_fps
103214
self._stop_processing = False
215+
self.use_data_channel_frames = use_data_channel_frames
216+
self._data_frame_queue: "asyncio.Queue[Optional[VideoFrame]]" = asyncio.Queue()
217+
self._chunk_reassembler = (
218+
ChunkReassembler()
219+
) # For reassembling inbound frame chunks
104220

105221
self.has_video_track = has_video_track
106222
self.stream_output = stream_output
@@ -185,7 +301,9 @@ async def _send_data_output(
185301
)
186302

187303
if self._data_mode == DataOutputMode.NONE:
188-
self.data_channel.send(json.dumps(webrtc_output.model_dump()))
304+
# Even empty responses use binary protocol
305+
json_bytes = json.dumps(webrtc_output.model_dump()).encode("utf-8")
306+
send_chunked_data(self.data_channel, self._received_frames, json_bytes)
189307
return
190308

191309
if self._data_mode == DataOutputMode.ALL:
@@ -216,11 +334,55 @@ async def _send_data_output(
216334
webrtc_output.errors.append(f"{field_name}: {e}")
217335
serialized_outputs[field_name] = {"__serialization_error__": str(e)}
218336

219-
# Only set serialized_output_data if we have data to send
337+
# Set serialized outputs
220338
if serialized_outputs:
221339
webrtc_output.serialized_output_data = serialized_outputs
222340

223-
self.data_channel.send(json.dumps(webrtc_output.model_dump(mode="json")))
341+
# Send using binary chunked protocol
342+
json_bytes = json.dumps(webrtc_output.model_dump(mode="json")).encode("utf-8")
343+
send_chunked_data(self.data_channel, self._received_frames, json_bytes)
344+
345+
async def _handle_data_channel_frame(self, message: bytes) -> None:
346+
"""Handle incoming binary frame chunk from upstream_frames data channel.
347+
348+
Uses standard binary protocol with 12-byte header + JPEG chunk payload.
349+
"""
350+
try:
351+
# Parse message
352+
frame_id, chunk_index, total_chunks, jpeg_chunk = (
353+
parse_chunked_binary_message(message)
354+
)
355+
356+
# Add chunk and check if complete
357+
jpeg_bytes = self._chunk_reassembler.add_chunk(
358+
frame_id, chunk_index, total_chunks, jpeg_chunk
359+
)
360+
361+
if jpeg_bytes is None:
362+
# Still waiting for more chunks
363+
return
364+
365+
# All chunks received - decode and queue frame
366+
if frame_id % 100 == 1:
367+
logger.info(
368+
f"Received frame {frame_id}: {total_chunks} chunk(s), {len(jpeg_bytes)} bytes JPEG"
369+
)
370+
371+
nparr = np.frombuffer(jpeg_bytes, np.uint8)
372+
np_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
373+
374+
if np_image is None:
375+
logger.error(f"Failed to decode JPEG for frame {frame_id}")
376+
return
377+
378+
video_frame = VideoFrame.from_ndarray(np_image, format="bgr24")
379+
await self._data_frame_queue.put((frame_id, video_frame))
380+
381+
if frame_id % 100 == 1:
382+
logger.info(f"Queued frame {frame_id}")
383+
384+
except Exception as e:
385+
logger.error(f"Error handling frame chunk: {e}", exc_info=True)
224386

225387
async def process_frames_data_only(self):
226388
"""Process frames for data extraction only, without video track output.
@@ -232,24 +394,37 @@ async def process_frames_data_only(self):
232394
av_logging.set_libav_level(av_logging.ERROR)
233395
self._av_logging_set = True
234396

235-
logger.info("Starting data-only frame processing")
397+
logger.info(
398+
f"Starting data-only frame processing (use_data_channel_frames={self.use_data_channel_frames})"
399+
)
236400

237401
try:
238-
while (
239-
self.track
240-
and self.track.readyState != "ended"
241-
and not self._stop_processing
242-
):
402+
while not self._stop_processing:
243403
if self._check_termination():
244404
break
245405

246-
# Drain queue if using PlayerStreamTrack (RTSP)
247-
if isinstance(self.track, PlayerStreamTrack):
248-
while self.track._queue.qsize() > 30:
249-
self.track._queue.get_nowait()
406+
# Get frame from appropriate source
407+
if self.use_data_channel_frames:
408+
# Wait for frame from data channel queue
409+
item = await self._data_frame_queue.get()
410+
if item is None:
411+
logger.info("Received stop signal from data channel")
412+
break
413+
frame_id, frame = item
414+
self._received_frames = frame_id
415+
else:
416+
# Get frame from media track (existing behavior)
417+
if not self.track or self.track.readyState == "ended":
418+
break
419+
420+
# Drain queue if using PlayerStreamTrack (RTSP)
421+
if isinstance(self.track, PlayerStreamTrack):
422+
while self.track._queue.qsize() > 30:
423+
self.track._queue.get_nowait()
424+
425+
frame = await self.track.recv()
426+
self._received_frames += 1
250427

251-
frame: VideoFrame = await self.track.recv()
252-
self._received_frames += 1
253428
frame_timestamp = datetime.datetime.now()
254429

255430
workflow_output, _, errors = await self._process_frame_async(
@@ -372,6 +547,7 @@ def __init__(
372547
declared_fps: float = 30,
373548
termination_date: Optional[datetime.datetime] = None,
374549
terminate_event: Optional[asyncio.Event] = None,
550+
use_data_channel_frames: bool = False,
375551
*args,
376552
**kwargs,
377553
):
@@ -387,6 +563,7 @@ def __init__(
387563
declared_fps=declared_fps,
388564
termination_date=termination_date,
389565
terminate_event=terminate_event,
566+
use_data_channel_frames=use_data_channel_frames,
390567
model_manager=model_manager,
391568
)
392569

@@ -531,6 +708,7 @@ async def init_rtc_peer_connection_with_loop(
531708
declared_fps=webrtc_request.declared_fps,
532709
termination_date=termination_date,
533710
terminate_event=terminate_event,
711+
use_data_channel_frames=webrtc_request.use_data_channel_frames,
534712
)
535713
else:
536714
# No video track - use base VideoFrameProcessor
@@ -545,6 +723,7 @@ async def init_rtc_peer_connection_with_loop(
545723
declared_fps=webrtc_request.declared_fps,
546724
termination_date=termination_date,
547725
terminate_event=terminate_event,
726+
use_data_channel_frames=webrtc_request.use_data_channel_frames,
548727
)
549728
except (
550729
ValidationError,
@@ -679,6 +858,23 @@ async def on_connectionstatechange():
679858
def on_datachannel(channel: RTCDataChannel):
680859
logger.info("Data channel '%s' received", channel.label)
681860

861+
# Handle upstream frames channel (client sending frames to server)
862+
if channel.label == "upstream_frames":
863+
logger.info(
864+
"Upstream frames channel established, starting data-only processing"
865+
)
866+
867+
@channel.on("message")
868+
def on_frame_message(message):
869+
asyncio.create_task(video_processor._handle_data_channel_frame(message))
870+
871+
# Start processing immediately since we won't get a media track
872+
if webrtc_request.use_data_channel_frames and not should_send_video:
873+
asyncio.create_task(video_processor.process_frames_data_only())
874+
875+
return
876+
877+
# Handle inference control channel (bidirectional communication)
682878
@channel.on("message")
683879
def on_message(message):
684880
logger.info("Data channel message received: %s", message)

0 commit comments

Comments
 (0)