From 38ab01ba60f8aded0148091be7ec2a16ace236f1 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Tue, 25 Nov 2025 06:17:54 -0800 Subject: [PATCH 1/2] feat: add bi-directional streaming dealer and router clients --- src/aiperf/common/base_comms.py | 38 +- .../common/enums/communication_enums.py | 2 + src/aiperf/common/protocols.py | 75 +++ src/aiperf/zmq/__init__.py | 8 + src/aiperf/zmq/streaming_dealer_client.py | 189 +++++++ src/aiperf/zmq/streaming_router_client.py | 206 ++++++++ src/aiperf/zmq/zmq_base_client.py | 12 +- tests/unit/zmq/conftest.py | 92 ++++ .../unit/zmq/test_streaming_dealer_client.py | 467 ++++++++++++++++++ .../unit/zmq/test_streaming_router_client.py | 392 +++++++++++++++ tests/unit/zmq/test_zmq_base_client.py | 131 ++++- 11 files changed, 1605 insertions(+), 7 deletions(-) create mode 100644 src/aiperf/zmq/streaming_dealer_client.py create mode 100644 src/aiperf/zmq/streaming_router_client.py create mode 100644 tests/unit/zmq/test_streaming_dealer_client.py create mode 100644 tests/unit/zmq/test_streaming_router_client.py diff --git a/src/aiperf/common/base_comms.py b/src/aiperf/common/base_comms.py index 45ce64686..6956f2390 100644 --- a/src/aiperf/common/base_comms.py +++ b/src/aiperf/common/base_comms.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import cast +from typing import Any, cast from aiperf.common.decorators import implements_protocol from aiperf.common.enums import CommClientType @@ -14,6 +14,8 @@ PushClientProtocol, ReplyClientProtocol, RequestClientProtocol, + StreamingDealerClientProtocol, + StreamingRouterClientProtocol, SubClientProtocol, ) from aiperf.common.types import CommAddressType @@ -42,6 +44,7 @@ def create_client( bind: bool = False, socket_ops: dict | None = None, max_pull_concurrency: int | None = None, + **kwargs: Any, ) -> CommunicationClientProtocol: """Create a communication client for a given client type and address. @@ -51,6 +54,7 @@ def create_client( bind: Whether to bind or connect the socket. socket_ops: Additional socket options to set. max_pull_concurrency: The maximum number of concurrent pull requests to allow. (Only used for pull clients) + **kwargs: Additional keyword arguments passed to specific client types (e.g., identity for DEALER). 
""" def create_pub_client( @@ -125,3 +129,35 @@ def create_reply_client( ReplyClientProtocol, self.create_client(CommClientType.REPLY, address, bind, socket_ops), ) + + def create_streaming_router_client( + self, + address: CommAddressType, + bind: bool = True, + socket_ops: dict | None = None, + ) -> StreamingRouterClientProtocol: + return cast( + StreamingRouterClientProtocol, + self.create_client( + CommClientType.STREAMING_ROUTER, address, bind, socket_ops + ), + ) + + def create_streaming_dealer_client( + self, + address: CommAddressType, + identity: str, + bind: bool = False, + socket_ops: dict | None = None, + ) -> StreamingDealerClientProtocol: + # Identity must be passed through client_kwargs since it's specific to DEALER + return cast( + StreamingDealerClientProtocol, + self.create_client( + CommClientType.STREAMING_DEALER, + address, + bind, + socket_ops, + identity=identity, + ), + ) diff --git a/src/aiperf/common/enums/communication_enums.py b/src/aiperf/common/enums/communication_enums.py index dc931e9a5..f59956048 100644 --- a/src/aiperf/common/enums/communication_enums.py +++ b/src/aiperf/common/enums/communication_enums.py @@ -16,6 +16,8 @@ class CommClientType(CaseInsensitiveStrEnum): PULL = "pull" REQUEST = "request" REPLY = "reply" + STREAMING_ROUTER = "streaming_router" + STREAMING_DEALER = "streaming_dealer" class CommAddress(CaseInsensitiveStrEnum): diff --git a/src/aiperf/common/protocols.py b/src/aiperf/common/protocols.py index d46b28fa5..ff6584ac3 100644 --- a/src/aiperf/common/protocols.py +++ b/src/aiperf/common/protocols.py @@ -181,6 +181,59 @@ async def request_async( ) -> None: ... +@runtime_checkable +class StreamingRouterClientProtocol(CommunicationClientProtocol, Protocol): + """Protocol for ROUTER socket client with bidirectional streaming.""" + + def register_receiver( + self, + handler: Callable[[str, MessageT], Coroutine[Any, Any, None]], + ) -> None: + """ + Register handler for incoming messages from DEALER clients. 
+ + Args: + handler: Async function that takes (identity: str, message: Message) + """ + ... + + async def send_to(self, identity: str, message: MessageT) -> None: + """ + Send message to specific DEALER client by identity. + + Args: + identity: The DEALER client's identity (routing key) + message: The message to send + """ + ... + + +@runtime_checkable +class StreamingDealerClientProtocol(CommunicationClientProtocol, Protocol): + """Protocol for DEALER socket client with bidirectional streaming.""" + + def register_receiver( + self, + handler: Callable[[MessageT], Coroutine[Any, Any, None]], + ) -> None: + """ + Register handler for incoming messages from ROUTER. + + Args: + handler: Async function that takes (message: Message) + """ + ... + + async def send(self, message: MessageT) -> None: + """ + Send message to ROUTER. + + Args: + message: The message to send + """ + ... + + @runtime_checkable class SubClientProtocol(CommunicationClientProtocol, Protocol): async def subscribe( @@ -217,6 +270,7 @@ def create_client( bind: bool = False, socket_ops: dict | None = None, max_pull_concurrency: int | None = None, + **kwargs: Any, ) -> CommunicationClientProtocol: """Create a client for the given client type and address, which will be automatically started and stopped with the CommunicationProtocol instance.""" @@ -283,6 +337,27 @@ def create_reply_client( started and stopped with the CommunicationProtocol instance.""" ... + def create_streaming_router_client( + self, + address: CommAddressType, + bind: bool = True, + socket_ops: dict | None = None, + ) -> StreamingRouterClientProtocol: + """Create a STREAMING_ROUTER client for the given address, which will be automatically + started and stopped with the CommunicationProtocol instance.""" + ... 
+ + def create_streaming_dealer_client( + self, + address: CommAddressType, + identity: str, + bind: bool = False, + socket_ops: dict | None = None, + ) -> StreamingDealerClientProtocol: + """Create a STREAMING_DEALER client for the given address and identity, which will be automatically + started and stopped with the CommunicationProtocol instance.""" + ... + @runtime_checkable class MessageBusClientProtocol(PubClientProtocol, SubClientProtocol, Protocol): diff --git a/src/aiperf/zmq/__init__.py b/src/aiperf/zmq/__init__.py index 4d03483a5..66e8144d2 100644 --- a/src/aiperf/zmq/__init__.py +++ b/src/aiperf/zmq/__init__.py @@ -23,6 +23,12 @@ from aiperf.zmq.router_reply_client import ( ZMQRouterReplyClient, ) +from aiperf.zmq.streaming_dealer_client import ( + ZMQStreamingDealerClient, +) +from aiperf.zmq.streaming_router_client import ( + ZMQStreamingRouterClient, +) from aiperf.zmq.sub_client import ( ZMQSubClient, ) @@ -71,6 +77,8 @@ "ZMQPushPullProxy", "ZMQRouterReplyClient", "ZMQSocketDefaults", + "ZMQStreamingDealerClient", + "ZMQStreamingRouterClient", "ZMQSubClient", "ZMQTCPCommunication", "ZMQXPubXSubProxy", diff --git a/src/aiperf/zmq/streaming_dealer_client.py b/src/aiperf/zmq/streaming_dealer_client.py new file mode 100644 index 000000000..6fd2327e8 --- /dev/null +++ b/src/aiperf/zmq/streaming_dealer_client.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Streaming DEALER client for bidirectional communication with ROUTER.""" + +import asyncio +from collections.abc import Awaitable, Callable + +import zmq + +from aiperf.common.decorators import implements_protocol +from aiperf.common.enums import CommClientType +from aiperf.common.factories import CommunicationClientFactory +from aiperf.common.hooks import background_task, on_stop +from aiperf.common.messages import Message +from aiperf.common.protocols import StreamingDealerClientProtocol +from aiperf.common.utils import yield_to_event_loop +from aiperf.zmq.zmq_base_client import BaseZMQClient + + +@implements_protocol(StreamingDealerClientProtocol) +@CommunicationClientFactory.register(CommClientType.STREAMING_DEALER) +class ZMQStreamingDealerClient(BaseZMQClient): + """ + ZMQ DEALER socket client for bidirectional streaming with ROUTER. + + Unlike ZMQDealerRequestClient (request-response pattern), this client is + designed for streaming scenarios where messages flow bidirectionally without + request-response pairing. + + The DEALER socket sets an identity which allows the ROUTER to send messages back + to this specific DEALER instance. 
+ + ASCII Diagram: + ┌──────────────┐ ┌──────────────┐ + │ DEALER │◄──── Stream ──────►│ ROUTER │ + │ (Worker) │ │ (Manager) │ + │ │ │ │ + └──────────────┘ └──────────────┘ + + Usage Pattern: + - DEALER connects to ROUTER with a unique identity + - DEALER sends messages to ROUTER + - DEALER receives messages from ROUTER (routed by identity) + - No request-response pairing - pure streaming + - Supports concurrent message processing + + Example: + ```python + # Create via comms (recommended - handles lifecycle management) + dealer = comms.create_streaming_dealer_client( + address=CommAddress.CREDIT_ROUTER, # or "tcp://localhost:5555" + identity="worker-1", + ) + + async def handle_message(message: Message) -> None: + if message.message_type == MessageType.CREDIT_DROP: + do_some_work(message.credit) + await dealer.send(CreditReturnMessage(...)) + + dealer.register_receiver(handle_message) + + # Lifecycle managed by comms - initialize/start/stop comms instead + await comms.initialize() + await comms.start() + await dealer.send(WorkerReadyMessage(...)) + ... + await dealer.send(WorkerShutdownMessage(...)) + await comms.stop() + ``` + """ + + def __init__( + self, + address: str, + identity: str, + bind: bool = False, + socket_ops: dict | None = None, + **kwargs, + ) -> None: + """ + Initialize the streaming DEALER client. + + Args: + address: The address to connect to (e.g., "tcp://localhost:5555") + identity: Unique identity for this DEALER (used by ROUTER for routing) + bind: Whether to bind (True) or connect (False) the socket. + Usually False for DEALER. 
+ socket_ops: Additional socket options to set + **kwargs: Additional arguments passed to BaseZMQClient + """ + super().__init__( + zmq.SocketType.DEALER, + address, + bind, + socket_ops={**(socket_ops or {}), zmq.IDENTITY: identity.encode()}, + client_id=identity, + **kwargs, + ) + self.identity = identity + self._receiver_handler: Callable[[Message], Awaitable[None]] | None = None + + def register_receiver(self, handler: Callable[[Message], Awaitable[None]]) -> None: + """ + Register handler for incoming messages from ROUTER. + + The handler will be called for each message received. + + Args: + handler: Async function that takes (message: Message) + """ + if self._receiver_handler is not None: + raise ValueError("Receiver handler already registered") + self._receiver_handler = handler + self.debug( + lambda: f"Registered streaming DEALER receiver handler for {self.identity}" + ) + + @on_stop + async def _clear_receiver(self) -> None: + """Clear receiver handler on stop.""" + self._receiver_handler = None + + async def send(self, message: Message) -> None: + """ + Send message to ROUTER. + + Args: + message: The message to send + + Raises: + NotInitializedError: If socket not initialized + CommunicationError: If send fails + """ + await self._check_initialized() + + if not isinstance(message, Message): + raise TypeError( + f"message must be an instance of Message, got {type(message).__name__}" + ) + + try: + # DEALER automatically handles framing - use single-frame send + await self.socket.send(message.to_json_bytes()) + if self.is_trace_enabled: + self.trace(f"Sent message: {message}") + except Exception as e: + self.exception(f"Failed to send message: {e}") + raise + + @background_task(immediate=True, interval=None) + async def _streaming_dealer_receiver(self) -> None: + """ + Background task for receiving messages from ROUTER. + + Runs continuously until stop is requested. 
Receives messages with DEALER + envelope format: [empty_delimiter, message_bytes] or just [message_bytes] + """ + self.debug( + lambda: f"Streaming DEALER receiver task started for {self.identity}" + ) + + while not self.stop_requested: + try: + message_bytes = await self.socket.recv() + if self.is_trace_enabled: + self.trace(f"Received message: {message_bytes}") + message = Message.from_json(message_bytes) + + if self._receiver_handler: + self.execute_async(self._receiver_handler(message)) + else: + self.warning( + f"Received {message.message_type} message but no handler registered" + ) + + except zmq.Again: + self.debug("No data on dealer socket received, yielding to event loop") + await yield_to_event_loop() + except Exception as e: + self.exception(f"Exception receiving messages: {e}") + await yield_to_event_loop() + except asyncio.CancelledError: + self.debug("Streaming DEALER receiver task cancelled") + raise # re-raise the cancelled error + + self.debug( + lambda: f"Streaming DEALER receiver task stopped for {self.identity}" + ) diff --git a/src/aiperf/zmq/streaming_router_client.py b/src/aiperf/zmq/streaming_router_client.py new file mode 100644 index 000000000..d30262108 --- /dev/null +++ b/src/aiperf/zmq/streaming_router_client.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Streaming ROUTER client for bidirectional communication with DEALER clients.""" + +import asyncio +from collections.abc import Awaitable, Callable + +import zmq + +from aiperf.common.decorators import implements_protocol +from aiperf.common.enums import CommClientType +from aiperf.common.factories import CommunicationClientFactory +from aiperf.common.hooks import background_task, on_stop +from aiperf.common.messages import Message +from aiperf.common.protocols import StreamingRouterClientProtocol +from aiperf.common.utils import yield_to_event_loop +from aiperf.zmq.zmq_base_client import BaseZMQClient + + +@implements_protocol(StreamingRouterClientProtocol) +@CommunicationClientFactory.register(CommClientType.STREAMING_ROUTER) +class ZMQStreamingRouterClient(BaseZMQClient): + """ + ZMQ ROUTER socket client for bidirectional streaming with DEALER clients. + + Unlike ZMQRouterReplyClient (request-response pattern), this client is + designed for streaming scenarios where messages flow bidirectionally without + request-response pairing. 
+ + Features: + - Bidirectional streaming with automatic routing by peer identity + - Message-based peer lifecycle tracking (ready/shutdown messages) + - Works with both TCP and IPC transports + + ASCII Diagram: + ┌──────────────┐ ┌──────────────┐ + │ DEALER │◄──── Stream ──────►│ │ + │ (Worker) │ │ │ + └──────────────┘ │ │ + ┌──────────────┐ │ ROUTER │ + │ DEALER │◄──── Stream ──────►│ (Manager) │ + │ (Worker) │ │ │ + └──────────────┘ │ │ + ┌──────────────┐ │ │ + │ DEALER │◄──── Stream ──────►│ │ + │ (Worker) │ │ │ + └──────────────┘ └──────────────┘ + + Usage Pattern: + - ROUTER sends messages to specific DEALER clients by identity + - ROUTER receives messages from DEALER clients (identity included in envelope) + - No request-response pairing - pure streaming + - Supports concurrent message processing + - Automatic peer tracking via worker ready and shutdown messages + + Example: + ```python + # Create via comms (recommended - handles lifecycle management) + router = comms.create_streaming_router_client( + address=CommAddress.CREDIT_ROUTER, # or "tcp://*:5555" + bind=True, + ) + + async def handle_message(identity: str, message: Message) -> None: + if message.message_type == MessageType.WORKER_READY: + await register_worker(identity) + elif message.message_type == MessageType.WORKER_SHUTDOWN: + await unregister_worker(identity) + + router.register_receiver(handle_message) + + # Lifecycle managed by comms - initialize/start/stop comms instead + await comms.initialize() + await comms.start() + + # Send message to specific DEALER + await router.send_to("worker-1", CreditDropMessage(...)) + ... + await comms.stop() + ``` + """ + + def __init__( + self, + address: str, + bind: bool = True, + socket_ops: dict | None = None, + **kwargs, + ) -> None: + """ + Initialize the streaming ROUTER client. 
+ + Args: + address: The address to bind or connect to (e.g., "tcp://*:5555" or "ipc:///tmp/socket") + bind: Whether to bind (True) or connect (False) the socket + socket_ops: Additional socket options to set + **kwargs: Additional arguments passed to BaseZMQClient + """ + super().__init__(zmq.SocketType.ROUTER, address, bind, socket_ops, **kwargs) + self._receiver_handler: Callable[[str, Message], Awaitable[None]] | None = None + + def register_receiver( + self, handler: Callable[[str, Message], Awaitable[None]] + ) -> None: + """ + Register handler for incoming messages from DEALER clients. + + The handler will be called for each message received, with the DEALER's + identity (routing key) and the message. + + Args: + handler: Async function that takes (identity: str, message: Message) + """ + if self._receiver_handler is not None: + raise ValueError("Receiver handler already registered") + self._receiver_handler = handler + self.debug("Registered streaming ROUTER receiver handler") + + @on_stop + async def _clear_receiver(self) -> None: + """Clear receiver handler on stop.""" + self._receiver_handler = None + + async def send_to(self, identity: str, message: Message) -> None: + """ + Send message to specific DEALER client by identity. 
+ + Args: + identity: The DEALER client's identity (routing key) + message: The message to send + + Raises: + NotInitializedError: If socket not initialized + CommunicationError: If send fails + """ + await self._check_initialized() + + if not isinstance(message, Message): + raise TypeError( + f"message must be an instance of Message, got {type(message).__name__}" + ) + + try: + # Send using routing envelope pattern (identity string → bytes) + routing_envelope = (identity.encode(),) + await self.socket.send_multipart( + [*routing_envelope, message.to_json_bytes()] + ) + if self.is_trace_enabled: + self.trace(f"Sent message to {identity}: {message}") + except Exception as e: + self.exception(f"Failed to send message to {identity}: {e}") + raise + + @background_task(immediate=True, interval=None) + async def _streaming_router_receiver(self) -> None: + """ + Background task for receiving messages from DEALER clients. + + Runs continuously until stop is requested. Receives messages with ROUTER + envelope format: [identity, empty_delimiter, message_bytes] + """ + self.debug("Streaming ROUTER receiver task started") + + while not self.stop_requested: + try: + data = await self.socket.recv_multipart() + if self.is_trace_enabled: + self.trace(f"Received message: {data}") + + message = Message.from_json(data[-1]) + + routing_envelope: tuple[bytes, ...] 
= ( + tuple(data[:-1]) if len(data) > 1 else (b"",) + ) + + # Decode identity for tracking (first frame of routing envelope) + identity_bytes = routing_envelope[0] if routing_envelope else b"" + identity = identity_bytes.decode("utf-8") + + if self.is_trace_enabled: + self.trace( + f"Received {message.message_type} message from {identity}: {message}" + ) + + if self._receiver_handler: + self.execute_async(self._receiver_handler(identity, message)) + else: + self.warning( + f"Received {message.message_type} message but no handler registered" + ) + + except zmq.Again: + self.debug("Router receiver task timed out") + await yield_to_event_loop() + continue + except Exception as e: + if not self.stop_requested: + self.exception(f"Error in streaming ROUTER receiver: {e}") + await yield_to_event_loop() + except asyncio.CancelledError: + self.debug("Streaming ROUTER receiver task cancelled") + break + + self.debug("Streaming ROUTER receiver task stopped") diff --git a/src/aiperf/zmq/zmq_base_client.py b/src/aiperf/zmq/zmq_base_client.py index 49c6fed18..c3e49310f 100644 --- a/src/aiperf/zmq/zmq_base_client.py +++ b/src/aiperf/zmq/zmq_base_client.py @@ -43,7 +43,7 @@ def __init__( """ self.context: zmq.asyncio.Context = zmq.asyncio.Context.instance() self.socket_type: zmq.SocketType = socket_type - self.socket: zmq.asyncio.Socket = self.context.socket(self.socket_type) + self.socket: zmq.asyncio.Socket = None self.address: str = address self.bind: bool = bind self.socket_ops: dict = socket_ops or {} @@ -77,10 +77,20 @@ async def _initialize_socket(self) -> None: - Run the AIPerfHook.ON_INIT hooks """ try: + self.socket = self.context.socket(self.socket_type) self.debug( lambda: f"ZMQ {self.socket_type_name} socket initialized, try {'BIND' if self.bind else 'CONNECT'} to {self.address} ({self.client_id})" ) + if zmq.IDENTITY in self.socket_ops: + # IMPORTANT! 
Set IDENTITY socket option immediately after socket creation, BEFORE bind/connect + # otherwise it will not be properly set when the socket is bound/connected + self.socket.setsockopt(zmq.IDENTITY, self.socket_ops[zmq.IDENTITY]) + self.debug( + lambda: f"Set IDENTITY socket option: {self.socket_ops[zmq.IDENTITY]}" + ) + del self.socket_ops[zmq.IDENTITY] + if self.bind: self.socket.bind(self.address) else: diff --git a/tests/unit/zmq/conftest.py b/tests/unit/zmq/conftest.py index 240ee5304..2a3fa085f 100644 --- a/tests/unit/zmq/conftest.py +++ b/tests/unit/zmq/conftest.py @@ -437,6 +437,77 @@ async def create_client( return PullHelper(helper) +@pytest.fixture +def streaming_router_test_helper(mock_zmq_context, wait_for_background_task): + """Provide a helper for ZMQStreamingRouterClient tests.""" + from aiperf.zmq.streaming_router_client import ZMQStreamingRouterClient + + helper = BaseClientTestHelper(mock_zmq_context, wait_for_background_task) + + class StreamingRouterHelper: + def __init__(self, base_helper): + self._base = base_helper + + def setup_mock_socket(self, **kwargs): + return self._base.setup_mock_socket(**kwargs) + + @asynccontextmanager + async def create_client( + self, + address="tcp://*:5555", + bind=True, + auto_start=False, + **mock_kwargs, + ): + async with self._base.create_client( + ZMQStreamingRouterClient, + address=address, + bind=bind, + auto_start=auto_start, + **mock_kwargs, + ) as client: + yield client + + return StreamingRouterHelper(helper) + + +@pytest.fixture +def streaming_dealer_test_helper(mock_zmq_context, wait_for_background_task): + """Provide a helper for ZMQStreamingDealerClient tests.""" + from aiperf.zmq.streaming_dealer_client import ZMQStreamingDealerClient + + helper = BaseClientTestHelper(mock_zmq_context, wait_for_background_task) + + class StreamingDealerHelper: + def __init__(self, base_helper): + self._base = base_helper + + def setup_mock_socket(self, **kwargs): + return self._base.setup_mock_socket(**kwargs) + 
+ @asynccontextmanager + async def create_client( + self, + address="tcp://127.0.0.1:5555", + identity="worker-1", + bind=False, + auto_start=False, + **mock_kwargs, + ): + client_kwargs = {"identity": identity} + async with self._base.create_client( + ZMQStreamingDealerClient, + address=address, + bind=bind, + auto_start=auto_start, + client_kwargs=client_kwargs, + **mock_kwargs, + ) as client: + yield client + + return StreamingDealerHelper(helper) + + # Shared test data and error scenarios @pytest.fixture( params=[ @@ -494,3 +565,24 @@ async def callback(msg): return callback, event, received_messages return _create + + +@pytest.fixture +def multiple_identities(): + """Common worker identities for testing.""" + return ["worker-1", "worker-2", "worker-3"] + + +@pytest.fixture( + params=[ + "worker-1", + "worker_2", + "worker.3", + "worker:4", + "worker@host", + ], + ids=["dash", "underscore", "dot", "colon", "at-sign"], +) # fmt: skip +def special_identity(request): + """Various identity formats with special characters.""" + return request.param diff --git a/tests/unit/zmq/test_streaming_dealer_client.py b/tests/unit/zmq/test_streaming_dealer_client.py new file mode 100644 index 000000000..a93da11b9 --- /dev/null +++ b/tests/unit/zmq/test_streaming_dealer_client.py @@ -0,0 +1,467 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for streaming_dealer_client.py - ZMQStreamingDealerClient class. 
+""" + +import asyncio + +import pytest +import zmq + +from aiperf.common.enums import LifecycleState, MessageType +from aiperf.common.exceptions import NotInitializedError +from aiperf.common.messages import Message +from aiperf.zmq.streaming_dealer_client import ZMQStreamingDealerClient + + +class TestZMQStreamingDealerClientInitialization: + """Test ZMQStreamingDealerClient initialization.""" + + def test_init_creates_dealer_socket(self, mock_zmq_context): + """Test that initialization creates a DEALER socket.""" + client = ZMQStreamingDealerClient( + address="tcp://127.0.0.1:5555", + identity="worker-1", + bind=False, + ) + + assert client.socket_type == zmq.SocketType.DEALER + assert client.identity == "worker-1" + assert client._receiver_handler is None + + @pytest.mark.parametrize( + "address,identity,bind", + [ + ("tcp://127.0.0.1:5555", "worker-1", False), + ("tcp://127.0.0.1:5556", "worker-2", True), + ("ipc:///tmp/test.ipc", "worker-3", False), + ("ipc:///tmp/test.ipc", "worker-4", True), + ], + ids=["tcp_connect", "tcp_bind", "ipc_connect", "ipc_bind"], + ) # fmt: skip + def test_init_with_various_addresses( + self, address, identity, bind, mock_zmq_context + ): + """Test initialization with various address types.""" + client = ZMQStreamingDealerClient( + address=address, + identity=identity, + bind=bind, + ) + + assert client.address == address + assert client.identity == identity + assert client.bind == bind + + def test_init_sets_identity_socket_option(self, mock_zmq_context): + """Test that initialization sets IDENTITY socket option.""" + identity = "test-worker" + client = ZMQStreamingDealerClient( + address="tcp://127.0.0.1:5555", + identity=identity, + bind=False, + ) + + # Check that identity is in socket_ops + assert zmq.IDENTITY in client.socket_ops + assert client.socket_ops[zmq.IDENTITY] == identity.encode() + + def test_init_with_custom_socket_options(self, mock_zmq_context): + """Test initialization with custom socket options.""" + 
identity = "test-worker" + custom_ops = {zmq.IMMEDIATE: 1} + client = ZMQStreamingDealerClient( + address="tcp://127.0.0.1:5555", + identity=identity, + bind=False, + socket_ops=custom_ops, + ) + + # Should have both identity and custom options + assert zmq.IDENTITY in client.socket_ops + assert zmq.IMMEDIATE in client.socket_ops + + def test_init_sets_client_id(self, mock_zmq_context): + """Test that initialization sets client_id to identity.""" + identity = "test-worker" + client = ZMQStreamingDealerClient( + address="tcp://127.0.0.1:5555", + identity=identity, + bind=False, + ) + + assert client.client_id == identity + + +class TestZMQStreamingDealerClientRegisterReceiver: + """Test ZMQStreamingDealerClient.register_receiver method.""" + + @pytest.mark.asyncio + async def test_register_receiver_succeeds(self, mock_zmq_context): + """Test that register_receiver successfully registers a handler.""" + client = ZMQStreamingDealerClient( + address="tcp://127.0.0.1:5555", + identity="worker-1", + bind=False, + ) + + async def handler(message: Message) -> None: + pass + + client.register_receiver(handler) + + assert client._receiver_handler == handler + + @pytest.mark.asyncio + async def test_register_receiver_raises_when_already_registered( + self, mock_zmq_context + ): + """Test that register_receiver raises ValueError if already registered.""" + client = ZMQStreamingDealerClient( + address="tcp://127.0.0.1:5555", + identity="worker-1", + bind=False, + ) + + async def handler1(message: Message) -> None: + pass + + async def handler2(message: Message) -> None: + pass + + client.register_receiver(handler1) + + with pytest.raises(ValueError, match="already registered"): + client.register_receiver(handler2) + + +class TestZMQStreamingDealerClientSend: + """Test ZMQStreamingDealerClient.send method.""" + + @pytest.mark.asyncio + async def test_send_sends_message( + self, streaming_dealer_test_helper, sample_message + ): + """Test that send sends message correctly.""" + 
async with streaming_dealer_test_helper.create_client() as client: + mock_socket = client.socket + + await client.send(sample_message) + + mock_socket.send.assert_called_once() + sent_data = mock_socket.send.call_args[0][0] + assert sample_message.request_id in sent_data.decode() + + @pytest.mark.asyncio + async def test_send_multiple_messages(self, streaming_dealer_test_helper): + """Test sending multiple messages.""" + async with streaming_dealer_test_helper.create_client() as client: + mock_socket = client.socket + messages = [ + Message(message_type=MessageType.HEARTBEAT, request_id=f"req-{i}") + for i in range(3) + ] + + for message in messages: + await client.send(message) + + assert mock_socket.send.call_count == len(messages) + + @pytest.mark.asyncio + async def test_send_raises_when_not_initialized( + self, streaming_dealer_test_helper, sample_message + ): + """Test that send raises NotInitializedError when not initialized.""" + client = ZMQStreamingDealerClient( + address="tcp://127.0.0.1:5555", + identity="worker-1", + bind=False, + ) + client.socket = None + + with pytest.raises(NotInitializedError, match="Socket not initialized"): + await client.send(sample_message) + + @pytest.mark.asyncio + async def test_send_raises_on_non_message_type(self, streaming_dealer_test_helper): + """Test that send raises TypeError for non-Message objects.""" + async with streaming_dealer_test_helper.create_client() as client: + with pytest.raises(TypeError, match="must be an instance of Message"): + await client.send("not a message") + + @pytest.mark.asyncio + async def test_send_handles_send_failure(self, streaming_dealer_test_helper): + """Test that send handles send failures.""" + async with streaming_dealer_test_helper.create_client( + send_side_effect=Exception("Send failed") + ) as client: + message = Message(message_type=MessageType.HEARTBEAT, request_id="test-123") + + with pytest.raises(Exception, match="Send failed"): + await client.send(message) + + +class 
TestZMQStreamingDealerClientReceiver: + """Test ZMQStreamingDealerClient receiver background task.""" + + @pytest.mark.asyncio + async def test_receiver_task_starts_on_start(self, streaming_dealer_test_helper): + """Test that receiver task starts when client starts.""" + async with streaming_dealer_test_helper.create_client( + auto_start=True + ) as client: + assert client.state == LifecycleState.RUNNING + + @pytest.mark.asyncio + async def test_receiver_calls_handler_on_message( + self, streaming_dealer_test_helper, sample_message, create_callback_tracker + ): + """Test that receiver calls handler when message arrives.""" + callback, event, received = create_callback_tracker() + + async def test_handler(message: Message) -> None: + await callback(message) + + call_count = 0 + + async def mock_recv(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return sample_message.to_json_bytes() + await asyncio.Future() # Block forever after first call + + streaming_dealer_test_helper.setup_mock_socket(recv_side_effect=mock_recv) + + async with streaming_dealer_test_helper.create_client() as client: + # Register handler BEFORE starting to avoid race condition + client.register_receiver(test_handler) + await client.start() + + await asyncio.wait_for(event.wait(), timeout=1.0) + assert len(received) == 1 + recv_message = received[0] + assert recv_message.request_id == sample_message.request_id + + @pytest.mark.asyncio + async def test_receiver_warns_when_no_handler_registered( + self, streaming_dealer_test_helper, sample_message, wait_for_background_task + ): + """Test that receiver logs warning when no handler is registered.""" + call_count = 0 + + async def mock_recv(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return sample_message.to_json_bytes() + await asyncio.Future() # Block forever after first call + + streaming_dealer_test_helper.setup_mock_socket(recv_side_effect=mock_recv) + + async with 
streaming_dealer_test_helper.create_client(auto_start=True): + # Don't register handler + await wait_for_background_task(iterations=5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exception,iterations", + [ + (zmq.Again(), 3), + (RuntimeError("Test error"), 3), + ], + ids=["zmq_again", "generic_error"], + ) # fmt: skip + async def test_receiver_handles_exceptions( + self, + streaming_dealer_test_helper, + wait_for_background_task, + exception, + iterations, + ): + """Test that receiver handles exceptions gracefully.""" + call_count = 0 + + async def mock_recv(): + nonlocal call_count + call_count += 1 + if call_count < iterations: + raise exception + await asyncio.Future() # Block forever after + + streaming_dealer_test_helper.setup_mock_socket(recv_side_effect=mock_recv) + + async with streaming_dealer_test_helper.create_client(auto_start=True): + await wait_for_background_task(iterations=5) + assert call_count >= iterations + + @pytest.mark.asyncio + async def test_receiver_stops_on_cancelled_error( + self, streaming_dealer_test_helper, wait_for_background_task + ): + """Test that receiver stops gracefully on CancelledError.""" + streaming_dealer_test_helper.setup_mock_socket( + recv_side_effect=asyncio.CancelledError() + ) + + async with streaming_dealer_test_helper.create_client( + auto_start=True + ) as client: + await wait_for_background_task() + # The receiver task should exit gracefully without raising an unhandled exception + # Client remains in RUNNING state until explicitly stopped + assert client.state == LifecycleState.RUNNING + + +class TestZMQStreamingDealerClientLifecycle: + """Test ZMQStreamingDealerClient lifecycle management.""" + + @pytest.mark.asyncio + async def test_clear_receiver_on_stop(self, streaming_dealer_test_helper): + """Test that receiver handler is cleared on stop.""" + async with streaming_dealer_test_helper.create_client() as client: + + async def handler(message: Message) -> None: + pass + + 
client.register_receiver(handler) + assert client._receiver_handler == handler + + # After context exits (which calls stop), handler should be cleared + assert client._receiver_handler is None + + @pytest.mark.asyncio + async def test_full_lifecycle( + self, streaming_dealer_test_helper, wait_for_background_task + ): + """Test full client lifecycle: initialize -> start -> stop.""" + async with streaming_dealer_test_helper.create_client() as client: + + async def handler(message: Message) -> None: + pass + + client.register_receiver(handler) + assert client.state == LifecycleState.INITIALIZED + + await client.start() + await wait_for_background_task() + assert client.state == LifecycleState.RUNNING + + # Context exit calls stop + assert client.state == LifecycleState.STOPPED + assert client._receiver_handler is None + + @pytest.mark.asyncio + async def test_send_after_stop_raises( + self, streaming_dealer_test_helper, sample_message + ): + """Test that send raises after client is stopped.""" + async with streaming_dealer_test_helper.create_client() as client: + pass + + # Client is now stopped after context exit + with pytest.raises(asyncio.CancelledError, match="Socket was stopped"): + await client.send(sample_message) + + +class TestZMQStreamingDealerClientEdgeCases: + """Test edge cases and error handling.""" + + @pytest.mark.asyncio + async def test_multiple_concurrent_sends( + self, streaming_dealer_test_helper, sample_message + ): + """Test multiple concurrent sends.""" + async with streaming_dealer_test_helper.create_client() as client: + mock_socket = client.socket + num_messages = 5 + + await asyncio.gather( + *[client.send(sample_message) for _ in range(num_messages)] + ) + + assert mock_socket.send.call_count == num_messages + + @pytest.mark.asyncio + async def test_different_message_types(self, streaming_dealer_test_helper): + """Test sending different message types.""" + async with streaming_dealer_test_helper.create_client() as client: + mock_socket = 
client.socket + messages = [ + Message(message_type=MessageType.HEARTBEAT, request_id="req-1"), + Message(message_type=MessageType.ERROR, request_id="req-2"), + ] + + for message in messages: + await client.send(message) + + assert mock_socket.send.call_count == len(messages) + + @pytest.mark.asyncio + async def test_receiver_with_multiple_messages( + self, streaming_dealer_test_helper, sample_message + ): + """Test receiver processing multiple messages.""" + # Deliver the same sample_message three times; each copy keeps its request_id + messages = [sample_message] * 3 + + message_index = 0 + received = [] + received_event = asyncio.Event() + + async def mock_recv(): + nonlocal message_index + if message_index < len(messages): + result = messages[message_index].to_json_bytes() + message_index += 1 + return result + await asyncio.Future() # Block forever after all messages + + streaming_dealer_test_helper.setup_mock_socket(recv_side_effect=mock_recv) + + async def test_handler(message: Message) -> None: + received.append(message) + if len(received) == len(messages): + received_event.set() + + async with streaming_dealer_test_helper.create_client() as client: + client.register_receiver(test_handler) + await client.start() + + await asyncio.wait_for(received_event.wait(), timeout=2.0) + + assert len(received) == len(messages) + for msg in received: + assert msg.request_id == sample_message.request_id + + @pytest.mark.parametrize( + "identity", + ["worker-1", "worker_2", "worker.3", "worker:4", "worker@host"], + ids=["dash", "underscore", "dot", "colon", "at-sign"], + ) # fmt: skip + def test_identity_with_special_characters(self, mock_zmq_context, identity): + """Test creating client with various identity formats.""" + client = ZMQStreamingDealerClient( + address="tcp://127.0.0.1:5555", + identity=identity, + bind=False, + ) + assert client.identity == identity + assert client.socket_ops[zmq.IDENTITY] == identity.encode() + + @pytest.mark.asyncio + async def 
test_bind_mode(self, mock_zmq_socket, mock_zmq_context): + """Test DEALER client in bind mode (unusual but supported).""" + client = ZMQStreamingDealerClient( + address="tcp://*:5555", + identity="worker-1", + bind=True, # Bind instead of connect + ) + await client.initialize() + + # Should bind, not connect + mock_zmq_socket.bind.assert_called_once_with("tcp://*:5555") + assert not mock_zmq_socket.connect.called + + await client.stop() diff --git a/tests/unit/zmq/test_streaming_router_client.py b/tests/unit/zmq/test_streaming_router_client.py new file mode 100644 index 000000000..3f4dd45e0 --- /dev/null +++ b/tests/unit/zmq/test_streaming_router_client.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for streaming_router_client.py - ZMQStreamingRouterClient class. +""" + +import asyncio + +import pytest +import zmq + +from aiperf.common.enums import LifecycleState, MessageType +from aiperf.common.exceptions import NotInitializedError +from aiperf.common.messages import Message +from aiperf.zmq.streaming_router_client import ZMQStreamingRouterClient + + +class TestZMQStreamingRouterClientInitialization: + """Test ZMQStreamingRouterClient initialization.""" + + def test_init_creates_router_socket(self, mock_zmq_context): + """Test that initialization creates a ROUTER socket.""" + client = ZMQStreamingRouterClient(address="tcp://*:5555", bind=True) + + assert client.socket_type == zmq.SocketType.ROUTER + assert client._receiver_handler is None + + @pytest.mark.parametrize( + "address,bind", + [ + ("tcp://*:5555", True), + ("tcp://127.0.0.1:5556", False), + ("ipc:///tmp/test.ipc", True), + ("ipc:///tmp/test.ipc", False), + ], + ids=["tcp_bind", "tcp_connect", "ipc_bind", "ipc_connect"], + ) # fmt: skip + def test_init_with_various_addresses(self, address, bind, mock_zmq_context): + """Test initialization with various address types.""" + client = 
ZMQStreamingRouterClient(address=address, bind=bind) + + assert client.address == address + assert client.bind == bind + + def test_init_with_custom_socket_options(self, mock_zmq_context): + """Test initialization with custom socket options.""" + custom_ops = {zmq.ROUTER_MANDATORY: 1} + client = ZMQStreamingRouterClient( + address="tcp://*:5555", + bind=True, + socket_ops=custom_ops, + ) + + assert client.socket_ops == custom_ops + + +class TestZMQStreamingRouterClientRegisterReceiver: + """Test ZMQStreamingRouterClient.register_receiver method.""" + + @pytest.mark.asyncio + async def test_register_receiver_succeeds(self, mock_zmq_context): + """Test that register_receiver successfully registers a handler.""" + client = ZMQStreamingRouterClient(address="tcp://*:5555", bind=True) + + async def handler(identity: str, message: Message) -> None: + pass + + client.register_receiver(handler) + assert client._receiver_handler == handler + + @pytest.mark.asyncio + async def test_register_receiver_raises_when_already_registered( + self, mock_zmq_context + ): + """Test that register_receiver raises ValueError if already registered.""" + client = ZMQStreamingRouterClient(address="tcp://*:5555", bind=True) + + async def handler1(identity: str, message: Message) -> None: + pass + + async def handler2(identity: str, message: Message) -> None: + pass + + client.register_receiver(handler1) + + with pytest.raises(ValueError, match="already registered"): + client.register_receiver(handler2) + + +class TestZMQStreamingRouterClientSendTo: + """Test ZMQStreamingRouterClient.send_to method.""" + + @pytest.mark.asyncio + async def test_send_to_sends_message_with_routing( + self, streaming_router_test_helper, sample_message + ): + """Test that send_to sends message with routing envelope.""" + async with streaming_router_test_helper.create_client() as client: + identity = "worker-1" + mock_socket = client.socket + + await client.send_to(identity, sample_message) + + 
mock_socket.send_multipart.assert_called_once() + sent_data = mock_socket.send_multipart.call_args[0][0] + assert sent_data[0] == identity.encode() + assert sample_message.request_id in sent_data[1].decode() + + @pytest.mark.asyncio + async def test_send_to_multiple_identities( + self, streaming_router_test_helper, sample_message, multiple_identities + ): + """Test sending to different worker identities.""" + async with streaming_router_test_helper.create_client() as client: + mock_socket = client.socket + + for identity in multiple_identities: + await client.send_to(identity, sample_message) + + assert mock_socket.send_multipart.call_count == len(multiple_identities) + + @pytest.mark.asyncio + async def test_send_to_raises_when_not_initialized( + self, mock_zmq_context, sample_message + ): + """Test that send_to raises NotInitializedError when not initialized.""" + client = ZMQStreamingRouterClient(address="tcp://*:5555", bind=True) + client.socket = None + + with pytest.raises(NotInitializedError, match="Socket not initialized"): + await client.send_to("worker-1", sample_message) + + @pytest.mark.asyncio + async def test_send_to_raises_on_non_message_type( + self, streaming_router_test_helper + ): + """Test that send_to raises TypeError for non-Message objects.""" + async with streaming_router_test_helper.create_client() as client: + with pytest.raises(TypeError, match="must be an instance of Message"): + await client.send_to("worker-1", "not a message") + + @pytest.mark.asyncio + async def test_send_to_handles_send_failure(self, streaming_router_test_helper): + """Test that send_to handles send failures.""" + async with streaming_router_test_helper.create_client( + send_multipart_side_effect=Exception("Send failed") + ) as client: + message = Message(message_type=MessageType.HEARTBEAT, request_id="test-123") + + with pytest.raises(Exception, match="Send failed"): + await client.send_to("worker-1", message) + + @pytest.mark.asyncio + async def 
test_send_to_with_special_identity( + self, streaming_router_test_helper, sample_message, special_identity + ): + """Test identity encoding with special characters.""" + async with streaming_router_test_helper.create_client() as client: + mock_socket = client.socket + + await client.send_to(special_identity, sample_message) + + sent_data = mock_socket.send_multipart.call_args[0][0] + assert sent_data[0] == special_identity.encode() + + +class TestZMQStreamingRouterClientReceiver: + """Test ZMQStreamingRouterClient receiver background task.""" + + @pytest.mark.asyncio + async def test_receiver_task_starts_on_start(self, streaming_router_test_helper): + """Test that receiver task starts when client starts.""" + async with streaming_router_test_helper.create_client( + auto_start=True + ) as client: + assert client.state == LifecycleState.RUNNING + + @pytest.mark.asyncio + async def test_receiver_calls_handler_on_message( + self, streaming_router_test_helper, sample_message, create_callback_tracker + ): + """Test that receiver calls handler when message arrives.""" + identity = "worker-1" + callback, event, received = create_callback_tracker() + + async def test_handler(recv_identity: str, message: Message) -> None: + await callback((recv_identity, message)) + + # Setup mock to return message once then block forever + call_count = 0 + + async def mock_recv(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [identity.encode(), sample_message.to_json_bytes()] + await asyncio.Future() # Block forever after first call + + streaming_router_test_helper.setup_mock_socket( + recv_multipart_side_effect=mock_recv + ) + + async with streaming_router_test_helper.create_client() as client: + # Register handler BEFORE starting to avoid race condition + client.register_receiver(test_handler) + await client.start() + + await asyncio.wait_for(event.wait(), timeout=1.0) + assert len(received) == 1 + recv_identity, recv_message = received[0] + assert recv_identity 
== identity + assert recv_message.request_id == sample_message.request_id + + @pytest.mark.asyncio + async def test_receiver_warns_when_no_handler_registered( + self, streaming_router_test_helper, sample_message, wait_for_background_task + ): + """Test that receiver logs warning when no handler is registered.""" + call_count = 0 + + async def mock_recv(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [b"worker-1", sample_message.to_json_bytes()] + await asyncio.Future() # Block forever after first call + + streaming_router_test_helper.setup_mock_socket( + recv_multipart_side_effect=mock_recv + ) + + async with streaming_router_test_helper.create_client(auto_start=True): + # Don't register handler + await wait_for_background_task(iterations=5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exception,iterations", + [ + (zmq.Again(), 3), + (RuntimeError("Test error"), 3), + ], + ids=["zmq_again", "generic_error"], + ) # fmt: skip + async def test_receiver_handles_exceptions( + self, + streaming_router_test_helper, + wait_for_background_task, + exception, + iterations, + ): + """Test that receiver handles exceptions gracefully.""" + call_count = 0 + + async def mock_recv(): + nonlocal call_count + call_count += 1 + if call_count < iterations: + raise exception + await asyncio.Future() # Block forever after + + streaming_router_test_helper.setup_mock_socket( + recv_multipart_side_effect=mock_recv + ) + + async with streaming_router_test_helper.create_client(auto_start=True): + await wait_for_background_task(iterations=5) + assert call_count >= iterations + + @pytest.mark.asyncio + async def test_receiver_stops_on_cancelled_error( + self, streaming_router_test_helper, wait_for_background_task + ): + """Test that receiver stops gracefully on CancelledError.""" + streaming_router_test_helper.setup_mock_socket( + recv_multipart_side_effect=asyncio.CancelledError() + ) + + async with streaming_router_test_helper.create_client( + 
auto_start=True + ) as client: + # Wait for the background task to run and exit due to CancelledError + await wait_for_background_task() + # The receiver task should exit gracefully without raising an unhandled exception + # Client remains in RUNNING state until explicitly stopped + assert client.state == LifecycleState.RUNNING + + @pytest.mark.asyncio + async def test_receiver_with_empty_routing_envelope( + self, streaming_router_test_helper, sample_message, create_callback_tracker + ): + """Test receiver handling of message with empty routing envelope.""" + callback, event, received = create_callback_tracker() + + async def test_handler(identity: str, message: Message) -> None: + await callback((identity, message)) + + call_count = 0 + + async def mock_recv(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [b"", sample_message.to_json_bytes()] + await asyncio.Future() # Block forever after first call + + streaming_router_test_helper.setup_mock_socket( + recv_multipart_side_effect=mock_recv + ) + + async with streaming_router_test_helper.create_client() as client: + # Register handler BEFORE starting to avoid race condition + client.register_receiver(test_handler) + await client.start() + + await asyncio.wait_for(event.wait(), timeout=1.0) + assert len(received) == 1 + recv_identity, _ = received[0] + assert recv_identity == "" # Empty identity + + +class TestZMQStreamingRouterClientLifecycle: + """Test ZMQStreamingRouterClient lifecycle management.""" + + @pytest.mark.asyncio + async def test_clear_receiver_on_stop(self, streaming_router_test_helper): + """Test that receiver handler is cleared on stop.""" + async with streaming_router_test_helper.create_client() as client: + + async def handler(identity: str, message: Message) -> None: + pass + + client.register_receiver(handler) + assert client._receiver_handler == handler + + # Client stopped after context exit + assert client._receiver_handler is None + + @pytest.mark.asyncio + async 
def test_full_lifecycle(self, streaming_router_test_helper): + """Test full client lifecycle: initialize -> start -> stop.""" + client = ZMQStreamingRouterClient(address="tcp://*:5555", bind=True) + + async def handler(identity: str, message: Message) -> None: + pass + + client.register_receiver(handler) + + # Initialize + await client.initialize() + assert client.state == LifecycleState.INITIALIZED + + # Start + await client.start() + assert client.state == LifecycleState.RUNNING + + # Stop + await client.stop() + assert client.state == LifecycleState.STOPPED + assert client._receiver_handler is None + + @pytest.mark.asyncio + async def test_send_to_after_stop_raises( + self, streaming_router_test_helper, sample_message + ): + """Test that send_to raises after client is stopped.""" + async with streaming_router_test_helper.create_client() as client: + pass # Client stopped after context exit + + with pytest.raises(asyncio.CancelledError, match="Socket was stopped"): + await client.send_to("worker-1", sample_message) + + +class TestZMQStreamingRouterClientEdgeCases: + """Test edge cases and error handling.""" + + @pytest.mark.asyncio + async def test_multiple_concurrent_sends( + self, streaming_router_test_helper, sample_message, multiple_identities + ): + """Test multiple concurrent sends to different workers.""" + async with streaming_router_test_helper.create_client() as client: + mock_socket = client.socket + + await asyncio.gather( + *[ + client.send_to(identity, sample_message) + for identity in multiple_identities + ] + ) + + assert mock_socket.send_multipart.call_count == len(multiple_identities) diff --git a/tests/unit/zmq/test_zmq_base_client.py b/tests/unit/zmq/test_zmq_base_client.py index 47d038b89..a447675b2 100644 --- a/tests/unit/zmq/test_zmq_base_client.py +++ b/tests/unit/zmq/test_zmq_base_client.py @@ -30,18 +30,19 @@ class TestBaseZMQClientInitialization: (zmq.SocketType.ROUTER, "tcp://127.0.0.1:5558", False), ], ) # fmt: skip - def 
test_init_creates_socket_with_correct_params( + def test_init_stores_correct_params( self, socket_type, address, bind, mock_zmq_context ): - """Test that __init__ creates socket with correct parameters.""" - # No need to patch - autouse fixture handles this + """Test that __init__ stores correct parameters (socket created on initialize).""" client = BaseZMQClient(socket_type=socket_type, address=address, bind=bind) assert client.socket_type == socket_type assert client.address == address assert client.bind == bind - assert client.socket is not None - mock_zmq_context.socket.assert_called_once_with(socket_type) + # Socket is None until initialize() is called + assert client.socket is None + # Socket not created yet + mock_zmq_context.socket.assert_not_called() def test_init_with_custom_client_id(self, mock_zmq_context): """Test initialization with custom client ID.""" @@ -269,3 +270,123 @@ async def test_check_initialized_succeeds_when_initialized( await client.initialize() # Should not raise await client._check_initialized() + + +class TestBaseZMQClientIdentity: + """Test BaseZMQClient IDENTITY socket option handling.""" + + @staticmethod + def get_identity_calls(mock_socket): + """Extract IDENTITY setsockopt calls from mock socket.""" + return [ + call[0][1] + for call in mock_socket.setsockopt.call_args_list + if call[0][0] == zmq.IDENTITY + ] + + @pytest.mark.asyncio + @pytest.mark.parametrize("bind", [True, False], ids=["bind", "connect"]) + async def test_identity_set_before_bind_or_connect( + self, bind, mock_zmq_socket, mock_zmq_context + ): + """Test that IDENTITY is set before bind/connect.""" + identity = b"test-identity" + client = BaseZMQClient( + socket_type=zmq.SocketType.DEALER, + address="tcp://127.0.0.1:5555", + bind=bind, + socket_ops={zmq.IDENTITY: identity}, + ) + + await client.initialize() + + # Verify IDENTITY was set correctly + assert self.get_identity_calls(mock_zmq_socket) == [identity] + + # Verify IDENTITY was set before bind/connect + 
all_calls = mock_zmq_socket.method_calls + identity_idx = next( + i + for i, call in enumerate(all_calls) + if call[0] == "setsockopt" and call[1][0] == zmq.IDENTITY + ) + action = "bind" if bind else "connect" + action_idx = next(i for i, call in enumerate(all_calls) if call[0] == action) + assert identity_idx < action_idx, f"IDENTITY must be set before {action}" + + @pytest.mark.asyncio + async def test_identity_removed_from_socket_ops_after_set( + self, mock_zmq_socket, mock_zmq_context + ): + """Test that IDENTITY is removed from socket_ops after being set.""" + client = BaseZMQClient( + socket_type=zmq.SocketType.DEALER, + address="tcp://127.0.0.1:5555", + bind=False, + socket_ops={zmq.IDENTITY: b"test", zmq.IMMEDIATE: 1}, + ) + + assert zmq.IDENTITY in client.socket_ops + await client.initialize() + assert zmq.IDENTITY not in client.socket_ops + assert zmq.IMMEDIATE in client.socket_ops + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "socket_type,bind", + [ + (zmq.SocketType.DEALER, False), + (zmq.SocketType.ROUTER, True), + ], + ids=["dealer", "router"], + ) + async def test_identity_with_socket_type( + self, socket_type, bind, mock_zmq_socket, mock_zmq_context + ): + """Test IDENTITY with DEALER and ROUTER socket types.""" + identity = b"test-identity" + client = BaseZMQClient( + socket_type=socket_type, + address="tcp://127.0.0.1:5555", + bind=bind, + socket_ops={zmq.IDENTITY: identity}, + ) + + await client.initialize() + + mock_zmq_context.socket.assert_called_once_with(socket_type) + assert self.get_identity_calls(mock_zmq_socket) == [identity] + + @pytest.mark.asyncio + async def test_no_identity_option(self, mock_zmq_socket, mock_zmq_context): + """Test that sockets without IDENTITY option work normally.""" + client = BaseZMQClient( + socket_type=zmq.SocketType.PUB, + address="tcp://127.0.0.1:5555", + bind=True, + ) + + await client.initialize() + + assert self.get_identity_calls(mock_zmq_socket) == [] + + @pytest.mark.asyncio + 
@pytest.mark.parametrize( + "identity", + [b"simple", b"worker-123", b"192.168.1.1:5555", b"a" * 255], + ids=["simple", "with-special-chars", "ip-port", "max-length"], + ) + async def test_various_identity_formats( + self, identity, mock_zmq_socket, mock_zmq_context + ): + """Test various identity string formats.""" + client = BaseZMQClient( + socket_type=zmq.SocketType.DEALER, + address="tcp://127.0.0.1:5555", + bind=False, + socket_ops={zmq.IDENTITY: identity}, + ) + + await client.initialize() + + assert self.get_identity_calls(mock_zmq_socket) == [identity] From 9e00c2694a7ecfdf1bd2d7a4c91d3ead7725f490 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Tue, 25 Nov 2025 10:30:16 -0800 Subject: [PATCH 2/2] Update streaming_dealer_client.py Signed-off-by: Anthony Casagrande --- src/aiperf/zmq/streaming_dealer_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aiperf/zmq/streaming_dealer_client.py b/src/aiperf/zmq/streaming_dealer_client.py index 6fd2327e8..d0ecaaa79 100644 --- a/src/aiperf/zmq/streaming_dealer_client.py +++ b/src/aiperf/zmq/streaming_dealer_client.py @@ -151,7 +151,7 @@ async def send(self, message: Message) -> None: @background_task(immediate=True, interval=None) async def _streaming_dealer_receiver(self) -> None: """ - Background task for receiving messages from ROUTER.xz + Background task for receiving messages from ROUTER. Runs continuously until stop is requested. Receives messages with DEALER envelope format: [empty_delimiter, message_bytes] or just [message_bytes]