Skip to content

Commit 42c6829

Browse files
authored
feat: add bi-directional streaming dealer and router zmq clients (#494)
Signed-off-by: Anthony Casagrande <[email protected]>
1 parent 563d9d8 commit 42c6829

File tree

11 files changed

+1605
-7
lines changed

11 files changed

+1605
-7
lines changed

src/aiperf/common/base_comms.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
from abc import ABC, abstractmethod
4-
from typing import cast
4+
from typing import Any, cast
55

66
from aiperf.common.decorators import implements_protocol
77
from aiperf.common.enums import CommClientType
@@ -14,6 +14,8 @@
1414
PushClientProtocol,
1515
ReplyClientProtocol,
1616
RequestClientProtocol,
17+
StreamingDealerClientProtocol,
18+
StreamingRouterClientProtocol,
1719
SubClientProtocol,
1820
)
1921
from aiperf.common.types import CommAddressType
@@ -42,6 +44,7 @@ def create_client(
4244
bind: bool = False,
4345
socket_ops: dict | None = None,
4446
max_pull_concurrency: int | None = None,
47+
**kwargs: Any,
4548
) -> CommunicationClientProtocol:
4649
"""Create a communication client for a given client type and address.
4750
@@ -51,6 +54,7 @@ def create_client(
5154
bind: Whether to bind or connect the socket.
5255
socket_ops: Additional socket options to set.
5356
max_pull_concurrency: The maximum number of concurrent pull requests to allow. (Only used for pull clients)
57+
**kwargs: Additional keyword arguments passed to specific client types (e.g., identity for DEALER).
5458
"""
5559

5660
def create_pub_client(
@@ -125,3 +129,35 @@ def create_reply_client(
125129
ReplyClientProtocol,
126130
self.create_client(CommClientType.REPLY, address, bind, socket_ops),
127131
)
132+
133+
def create_streaming_router_client(
134+
self,
135+
address: CommAddressType,
136+
bind: bool = True,
137+
socket_ops: dict | None = None,
138+
) -> StreamingRouterClientProtocol:
139+
return cast(
140+
StreamingRouterClientProtocol,
141+
self.create_client(
142+
CommClientType.STREAMING_ROUTER, address, bind, socket_ops
143+
),
144+
)
145+
146+
def create_streaming_dealer_client(
147+
self,
148+
address: CommAddressType,
149+
identity: str,
150+
bind: bool = False,
151+
socket_ops: dict | None = None,
152+
) -> StreamingDealerClientProtocol:
153+
# Identity must be passed through client_kwargs since it's specific to DEALER
154+
return cast(
155+
StreamingDealerClientProtocol,
156+
self.create_client(
157+
CommClientType.STREAMING_DEALER,
158+
address,
159+
bind,
160+
socket_ops,
161+
identity=identity,
162+
),
163+
)

src/aiperf/common/enums/communication_enums.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class CommClientType(CaseInsensitiveStrEnum):
1616
PULL = "pull"
1717
REQUEST = "request"
1818
REPLY = "reply"
19+
STREAMING_ROUTER = "streaming_router"
20+
STREAMING_DEALER = "streaming_dealer"
1921

2022

2123
class CommAddress(CaseInsensitiveStrEnum):

src/aiperf/common/protocols.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,59 @@ async def request_async(
181181
) -> None: ...
182182

183183

184+
@runtime_checkable
185+
class StreamingRouterClientProtocol(CommunicationClientProtocol, Protocol):
186+
"""Protocol for ROUTER socket client with bidirectional streaming."""
187+
188+
def register_receiver(
189+
self,
190+
handler: Callable[[str, MessageT], Coroutine[Any, Any, None]],
191+
) -> None:
192+
"""
193+
Register handler for incoming messages from DEALER clients.
194+
195+
Args:
196+
handler: Async function that takes (identity: str, message: Message)
197+
"""
198+
...
199+
200+
async def send_to(self, identity: str, message: MessageT) -> None:
201+
"""
202+
Send message to specific DEALER client by identity.
203+
204+
Args:
205+
identity: The DEALER client's identity (routing key)
206+
message: The message to send
207+
"""
208+
...
209+
210+
211+
@runtime_checkable
212+
class StreamingDealerClientProtocol(CommunicationClientProtocol, Protocol):
213+
"""Protocol for DEALER socket client with bidirectional streaming."""
214+
215+
def register_receiver(
216+
self,
217+
handler: Callable[[MessageT], Coroutine[Any, Any, None]],
218+
) -> None:
219+
"""
220+
Register handler for incoming messages from ROUTER.
221+
222+
Args:
223+
handler: Async function that takes (message: Message)
224+
"""
225+
...
226+
227+
async def send(self, message: MessageT) -> None:
228+
"""
229+
Send message to ROUTER.
230+
231+
Args:
232+
message: The message to send
233+
"""
234+
...
235+
236+
184237
@runtime_checkable
185238
class SubClientProtocol(CommunicationClientProtocol, Protocol):
186239
async def subscribe(
@@ -217,6 +270,7 @@ def create_client(
217270
bind: bool = False,
218271
socket_ops: dict | None = None,
219272
max_pull_concurrency: int | None = None,
273+
**kwargs: Any,
220274
) -> CommunicationClientProtocol:
221275
"""Create a client for the given client type and address, which will be automatically
222276
started and stopped with the CommunicationProtocol instance."""
@@ -283,6 +337,27 @@ def create_reply_client(
283337
started and stopped with the CommunicationProtocol instance."""
284338
...
285339

340+
def create_streaming_router_client(
341+
self,
342+
address: CommAddressType,
343+
bind: bool = True,
344+
socket_ops: dict | None = None,
345+
) -> StreamingRouterClientProtocol:
346+
"""Create a STREAMING_ROUTER client for the given address, which will be automatically
347+
started and stopped with the CommunicationProtocol instance."""
348+
...
349+
350+
def create_streaming_dealer_client(
351+
self,
352+
address: CommAddressType,
353+
identity: str,
354+
bind: bool = False,
355+
socket_ops: dict | None = None,
356+
) -> StreamingDealerClientProtocol:
357+
"""Create a STREAMING_DEALER client for the given address and identity, which will be automatically
358+
started and stopped with the CommunicationProtocol instance."""
359+
...
360+
286361

287362
@runtime_checkable
288363
class MessageBusClientProtocol(PubClientProtocol, SubClientProtocol, Protocol):

src/aiperf/zmq/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
from aiperf.zmq.router_reply_client import (
2424
ZMQRouterReplyClient,
2525
)
26+
from aiperf.zmq.streaming_dealer_client import (
27+
ZMQStreamingDealerClient,
28+
)
29+
from aiperf.zmq.streaming_router_client import (
30+
ZMQStreamingRouterClient,
31+
)
2632
from aiperf.zmq.sub_client import (
2733
ZMQSubClient,
2834
)
@@ -71,6 +77,8 @@
7177
"ZMQPushPullProxy",
7278
"ZMQRouterReplyClient",
7379
"ZMQSocketDefaults",
80+
"ZMQStreamingDealerClient",
81+
"ZMQStreamingRouterClient",
7482
"ZMQSubClient",
7583
"ZMQTCPCommunication",
7684
"ZMQXPubXSubProxy",
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Streaming DEALER client for bidirectional communication with ROUTER."""
5+
6+
import asyncio
7+
from collections.abc import Awaitable, Callable
8+
9+
import zmq
10+
11+
from aiperf.common.decorators import implements_protocol
12+
from aiperf.common.enums import CommClientType
13+
from aiperf.common.factories import CommunicationClientFactory
14+
from aiperf.common.hooks import background_task, on_stop
15+
from aiperf.common.messages import Message
16+
from aiperf.common.protocols import StreamingDealerClientProtocol
17+
from aiperf.common.utils import yield_to_event_loop
18+
from aiperf.zmq.zmq_base_client import BaseZMQClient
19+
20+
21+
@implements_protocol(StreamingDealerClientProtocol)
22+
@CommunicationClientFactory.register(CommClientType.STREAMING_DEALER)
23+
class ZMQStreamingDealerClient(BaseZMQClient):
24+
"""
25+
ZMQ DEALER socket client for bidirectional streaming with ROUTER.
26+
27+
Unlike ZMQDealerRequestClient (request-response pattern), this client is
28+
designed for streaming scenarios where messages flow bidirectionally without
29+
request-response pairing.
30+
31+
The DEALER socket sets an identity which allows the ROUTER to send messages back
32+
to this specific DEALER instance.
33+
34+
ASCII Diagram:
35+
┌──────────────┐ ┌──────────────┐
36+
│ DEALER │◄──── Stream ──────►│ ROUTER │
37+
│ (Worker) │ │ (Manager) │
38+
│ │ │ │
39+
└──────────────┘ └──────────────┘
40+
41+
Usage Pattern:
42+
- DEALER connects to ROUTER with a unique identity
43+
- DEALER sends messages to ROUTER
44+
- DEALER receives messages from ROUTER (routed by identity)
45+
- No request-response pairing - pure streaming
46+
- Supports concurrent message processing
47+
48+
Example:
49+
```python
50+
# Create via comms (recommended - handles lifecycle management)
51+
dealer = comms.create_streaming_dealer_client(
52+
address=CommAddress.CREDIT_ROUTER, # or "tcp://localhost:5555"
53+
identity="worker-1",
54+
)
55+
56+
async def handle_message(message: Message) -> None:
57+
if message.message_type == MessageType.CREDIT_DROP:
58+
do_some_work(message.credit)
59+
await dealer.send(CreditReturnMessage(...))
60+
61+
dealer.register_receiver(handle_message)
62+
63+
# Lifecycle managed by comms - initialize/start/stop comms instead
64+
await comms.initialize()
65+
await comms.start()
66+
await dealer.send(WorkerReadyMessage(...))
67+
...
68+
await dealer.send(WorkerShutdownMessage(...))
69+
await comms.stop()
70+
```
71+
"""
72+
73+
def __init__(
74+
self,
75+
address: str,
76+
identity: str,
77+
bind: bool = False,
78+
socket_ops: dict | None = None,
79+
**kwargs,
80+
) -> None:
81+
"""
82+
Initialize the streaming DEALER client.
83+
84+
Args:
85+
address: The address to connect to (e.g., "tcp://localhost:5555")
86+
identity: Unique identity for this DEALER (used by ROUTER for routing)
87+
bind: Whether to bind (True) or connect (False) the socket.
88+
Usually False for DEALER.
89+
socket_ops: Additional socket options to set
90+
**kwargs: Additional arguments passed to BaseZMQClient
91+
"""
92+
super().__init__(
93+
zmq.SocketType.DEALER,
94+
address,
95+
bind,
96+
socket_ops={**(socket_ops or {}), zmq.IDENTITY: identity.encode()},
97+
client_id=identity,
98+
**kwargs,
99+
)
100+
self.identity = identity
101+
self._receiver_handler: Callable[[Message], Awaitable[None]] | None = None
102+
103+
def register_receiver(self, handler: Callable[[Message], Awaitable[None]]) -> None:
104+
"""
105+
Register handler for incoming messages from ROUTER.
106+
107+
The handler will be called for each message received.
108+
109+
Args:
110+
handler: Async function that takes (message: Message)
111+
"""
112+
if self._receiver_handler is not None:
113+
raise ValueError("Receiver handler already registered")
114+
self._receiver_handler = handler
115+
self.debug(
116+
lambda: f"Registered streaming DEALER receiver handler for {self.identity}"
117+
)
118+
119+
@on_stop
120+
async def _clear_receiver(self) -> None:
121+
"""Clear receiver handler on stop."""
122+
self._receiver_handler = None
123+
124+
async def send(self, message: Message) -> None:
125+
"""
126+
Send message to ROUTER.
127+
128+
Args:
129+
message: The message to send
130+
131+
Raises:
132+
NotInitializedError: If socket not initialized
133+
CommunicationError: If send fails
134+
"""
135+
await self._check_initialized()
136+
137+
if not isinstance(message, Message):
138+
raise TypeError(
139+
f"message must be an instance of Message, got {type(message).__name__}"
140+
)
141+
142+
try:
143+
# DEALER automatically handles framing - use single-frame send
144+
await self.socket.send(message.to_json_bytes())
145+
if self.is_trace_enabled:
146+
self.trace(f"Sent message: {message}")
147+
except Exception as e:
148+
self.exception(f"Failed to send message: {e}")
149+
raise
150+
151+
@background_task(immediate=True, interval=None)
152+
async def _streaming_dealer_receiver(self) -> None:
153+
"""
154+
Background task for receiving messages from ROUTER.
155+
156+
Runs continuously until stop is requested. Receives messages with DEALER
157+
envelope format: [empty_delimiter, message_bytes] or just [message_bytes]
158+
"""
159+
self.debug(
160+
lambda: f"Streaming DEALER receiver task started for {self.identity}"
161+
)
162+
163+
while not self.stop_requested:
164+
try:
165+
message_bytes = await self.socket.recv()
166+
if self.is_trace_enabled:
167+
self.trace(f"Received message: {message_bytes}")
168+
message = Message.from_json(message_bytes)
169+
170+
if self._receiver_handler:
171+
self.execute_async(self._receiver_handler(message))
172+
else:
173+
self.warning(
174+
f"Received {message.message_type} message but no handler registered"
175+
)
176+
177+
except zmq.Again:
178+
self.debug("No data on dealer socket received, yielding to event loop")
179+
await yield_to_event_loop()
180+
except Exception as e:
181+
self.exception(f"Exception receiving messages: {e}")
182+
await yield_to_event_loop()
183+
except asyncio.CancelledError:
184+
self.debug("Streaming DEALER receiver task cancelled")
185+
raise # re-raise the cancelled error
186+
187+
self.debug(
188+
lambda: f"Streaming DEALER receiver task stopped for {self.identity}"
189+
)

0 commit comments

Comments
 (0)