Skip to content

Commit 98d25ed

Browse files
committed
format
1 parent 92570bc commit 98d25ed

File tree

2 files changed

+56
-43
lines changed

2 files changed

+56
-43
lines changed

inference/core/interfaces/webrtc_worker/entities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ class WebRTCWorkerRequest(BaseModel):
3939
data_output: Optional[List[str]] = Field(default=None)
4040
declared_fps: Optional[float] = None
4141
rtsp_url: Optional[str] = None
42-
use_data_channel_frames: bool = False # When True, expect frames via data channel instead of media track
42+
use_data_channel_frames: bool = (
43+
False # When True, expect frames via data channel instead of media track
44+
)
4345
processing_timeout: Optional[int] = WEBRTC_MODAL_FUNCTION_TIME_LIMIT
4446
processing_session_started: Optional[datetime.datetime] = None
4547
requested_plan: Optional[str] = "webrtc-gpu-small"

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -70,64 +70,65 @@ def create_chunked_binary_message(
7070
frame_id: int, chunk_index: int, total_chunks: int, payload: bytes
7171
) -> bytes:
7272
"""Create a binary message with standard 12-byte header.
73-
73+
7474
Format: [frame_id: 4][chunk_index: 4][total_chunks: 4][payload: N]
7575
All integers are uint32 little-endian.
7676
"""
77-
header = struct.pack('<III', frame_id, chunk_index, total_chunks)
77+
header = struct.pack("<III", frame_id, chunk_index, total_chunks)
7878
return header + payload
7979

8080

8181
def parse_chunked_binary_message(message: bytes) -> Tuple[int, int, int, bytes]:
8282
"""Parse a binary message with standard 12-byte header.
83-
83+
8484
Returns: (frame_id, chunk_index, total_chunks, payload)
8585
"""
8686
if len(message) < 12:
8787
raise ValueError(f"Message too short: {len(message)} bytes (expected >= 12)")
88-
89-
frame_id, chunk_index, total_chunks = struct.unpack('<III', message[0:12])
88+
89+
frame_id, chunk_index, total_chunks = struct.unpack("<III", message[0:12])
9090
payload = message[12:]
9191
return frame_id, chunk_index, total_chunks, payload
9292

9393

9494
class ChunkReassembler:
9595
"""Helper to reassemble chunked binary messages."""
96-
96+
9797
def __init__(self):
98-
self._chunks: Dict[int, Dict[int, bytes]] = {} # {frame_id: {chunk_index: data}}
98+
self._chunks: Dict[int, Dict[int, bytes]] = (
99+
{}
100+
) # {frame_id: {chunk_index: data}}
99101
self._total: Dict[int, int] = {} # {frame_id: total_chunks}
100-
102+
101103
def add_chunk(
102104
self, frame_id: int, chunk_index: int, total_chunks: int, chunk_data: bytes
103105
) -> Optional[bytes]:
104106
"""Add a chunk and return complete payload if all chunks received.
105-
107+
106108
Returns:
107109
Complete reassembled payload bytes if all chunks received, None otherwise.
108110
"""
109111
# Initialize buffers for new frame
110112
if frame_id not in self._chunks:
111113
self._chunks[frame_id] = {}
112114
self._total[frame_id] = total_chunks
113-
115+
114116
# Store chunk
115117
self._chunks[frame_id][chunk_index] = chunk_data
116-
118+
117119
# Check if all chunks received
118120
if len(self._chunks[frame_id]) >= total_chunks:
119121
# Reassemble in order
120-
complete_payload = b''.join(
121-
self._chunks[frame_id][i]
122-
for i in range(total_chunks)
122+
complete_payload = b"".join(
123+
self._chunks[frame_id][i] for i in range(total_chunks)
123124
)
124-
125+
125126
# Clean up
126127
del self._chunks[frame_id]
127128
del self._total[frame_id]
128-
129+
129130
return complete_payload
130-
131+
131132
return None
132133

133134

@@ -138,7 +139,7 @@ def send_chunked_data(
138139
chunk_size: int = CHUNK_SIZE,
139140
) -> None:
140141
"""Send payload via data channel, automatically chunking if needed.
141-
142+
142143
Args:
143144
data_channel: RTCDataChannel to send on
144145
frame_id: Frame identifier
@@ -148,19 +149,21 @@ def send_chunked_data(
148149
if data_channel.readyState != "open":
149150
logger.warning(f"Cannot send response for frame {frame_id}, channel not open")
150151
return
151-
152-
total_chunks = (len(payload_bytes) + chunk_size - 1) // chunk_size # Ceiling division
153-
152+
153+
total_chunks = (
154+
len(payload_bytes) + chunk_size - 1
155+
) // chunk_size # Ceiling division
156+
154157
if frame_id % 100 == 1:
155158
logger.info(
156159
f"Sending response for frame {frame_id}: {total_chunks} chunk(s), {len(payload_bytes)} bytes"
157160
)
158-
161+
159162
for chunk_index in range(total_chunks):
160163
start = chunk_index * chunk_size
161164
end = min(start + chunk_size, len(payload_bytes))
162165
chunk_data = payload_bytes[start:end]
163-
166+
164167
message = create_chunked_binary_message(
165168
frame_id, chunk_index, total_chunks, chunk_data
166169
)
@@ -209,7 +212,9 @@ def __init__(
209212
self._stop_processing = False
210213
self.use_data_channel_frames = use_data_channel_frames
211214
self._data_frame_queue: "asyncio.Queue[Optional[VideoFrame]]" = asyncio.Queue()
212-
self._chunk_reassembler = ChunkReassembler() # For reassembling inbound frame chunks
215+
self._chunk_reassembler = (
216+
ChunkReassembler()
217+
) # For reassembling inbound frame chunks
213218

214219
self.has_video_track = has_video_track
215220
self.stream_output = stream_output
@@ -294,7 +299,7 @@ async def _send_data_output(
294299

295300
if self._data_mode == DataOutputMode.NONE:
296301
# Even empty responses use binary protocol
297-
json_bytes = json.dumps(webrtc_output.model_dump()).encode('utf-8')
302+
json_bytes = json.dumps(webrtc_output.model_dump()).encode("utf-8")
298303
send_chunked_data(self.data_channel, self._received_frames, json_bytes)
299304
return
300305

@@ -331,46 +336,48 @@ async def _send_data_output(
331336
webrtc_output.serialized_output_data = serialized_outputs
332337

333338
# Send using binary chunked protocol
334-
json_bytes = json.dumps(webrtc_output.model_dump(mode="json")).encode('utf-8')
339+
json_bytes = json.dumps(webrtc_output.model_dump(mode="json")).encode("utf-8")
335340
send_chunked_data(self.data_channel, self._received_frames, json_bytes)
336341

337342
async def _handle_data_channel_frame(self, message: bytes) -> None:
338343
"""Handle incoming binary frame chunk from upstream_frames data channel.
339-
344+
340345
Uses standard binary protocol with 12-byte header + JPEG chunk payload.
341346
"""
342347
try:
343348
# Parse message
344-
frame_id, chunk_index, total_chunks, jpeg_chunk = parse_chunked_binary_message(message)
345-
349+
frame_id, chunk_index, total_chunks, jpeg_chunk = (
350+
parse_chunked_binary_message(message)
351+
)
352+
346353
# Add chunk and check if complete
347354
jpeg_bytes = self._chunk_reassembler.add_chunk(
348355
frame_id, chunk_index, total_chunks, jpeg_chunk
349356
)
350-
357+
351358
if jpeg_bytes is None:
352359
# Still waiting for more chunks
353360
return
354-
361+
355362
# All chunks received - decode and queue frame
356363
if frame_id % 100 == 1:
357364
logger.info(
358365
f"Received frame {frame_id}: {total_chunks} chunk(s), {len(jpeg_bytes)} bytes JPEG"
359366
)
360-
367+
361368
nparr = np.frombuffer(jpeg_bytes, np.uint8)
362369
np_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
363-
370+
364371
if np_image is None:
365372
logger.error(f"Failed to decode JPEG for frame {frame_id}")
366373
return
367-
374+
368375
video_frame = VideoFrame.from_ndarray(np_image, format="bgr24")
369376
await self._data_frame_queue.put((frame_id, video_frame))
370-
377+
371378
if frame_id % 100 == 1:
372379
logger.info(f"Queued frame {frame_id}")
373-
380+
374381
except Exception as e:
375382
logger.error(f"Error handling frame chunk: {e}", exc_info=True)
376383

@@ -384,7 +391,9 @@ async def process_frames_data_only(self):
384391
av_logging.set_libav_level(av_logging.ERROR)
385392
self._av_logging_set = True
386393

387-
logger.info(f"Starting data-only frame processing (use_data_channel_frames={self.use_data_channel_frames})")
394+
logger.info(
395+
f"Starting data-only frame processing (use_data_channel_frames={self.use_data_channel_frames})"
396+
)
388397

389398
try:
390399
while not self._stop_processing:
@@ -404,7 +413,7 @@ async def process_frames_data_only(self):
404413
# Get frame from media track (existing behavior)
405414
if not self.track or self.track.readyState == "ended":
406415
break
407-
416+
408417
# Drain queue if using PlayerStreamTrack (RTSP)
409418
if isinstance(self.track, PlayerStreamTrack):
410419
while self.track._queue.qsize() > 30:
@@ -843,16 +852,18 @@ def on_datachannel(channel: RTCDataChannel):
843852

844853
# Handle upstream frames channel (client sending frames to server)
845854
if channel.label == "upstream_frames":
846-
logger.info("Upstream frames channel established, starting data-only processing")
847-
855+
logger.info(
856+
"Upstream frames channel established, starting data-only processing"
857+
)
858+
848859
@channel.on("message")
849860
def on_frame_message(message):
850861
asyncio.create_task(video_processor._handle_data_channel_frame(message))
851-
862+
852863
# Start processing immediately since we won't get a media track
853864
if webrtc_request.use_data_channel_frames and not should_send_video:
854865
asyncio.create_task(video_processor.process_frames_data_only())
855-
866+
856867
return
857868

858869
# Handle inference control channel (bidirectional communication)

0 commit comments

Comments
 (0)