Skip to content

Commit cecb078

Browse files
committed
Handle cancellation of concurrent async generators
Signed-off-by: Waldemar Quevedo <[email protected]>
1 parent e3852c6 commit cecb078

File tree

3 files changed

+192
-1
lines changed

3 files changed

+192
-1
lines changed

nats/src/nats/aio/client.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,8 @@ async def _process_msg(
16971697
return
16981698

16991699
sub._received += 1
1700-
if sub._max_msgs > 0 and sub._received >= sub._max_msgs:
1700+
max_msgs_reached = sub._max_msgs > 0 and sub._received >= sub._max_msgs
1701+
if max_msgs_reached:
17011702
# Enough messages so can throwaway subscription now, the
17021703
# pending messages will still be in the subscription
17031704
# internal queue and the task will finish once the last
@@ -1800,6 +1801,16 @@ async def _process_msg(
18001801
if sub._jsi:
18011802
await sub._jsi.check_for_sequence_mismatch(msg)
18021803

1804+
# Send sentinel after reaching max messages for non-callback subscriptions.
1805+
if max_msgs_reached and not sub._cb and sub._active_generators > 0:
1806+
# Send one sentinel per active generator to unblock them all.
1807+
for _ in range(sub._active_generators):
1808+
try:
1809+
sub._pending_queue.put_nowait(None)
1810+
except Exception:
1811+
# Queue might be full or closed, that's ok
1812+
break
1813+
18031814
def _build_message(
18041815
self,
18051816
sid: int,

nats/src/nats/aio/subscription.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ def delivered(self) -> int:
196196
"""
197197
return self._received
198198

199+
@property
200+
def is_closed(self) -> bool:
201+
"""
202+
Returns True if the subscription is closed, False otherwise.
203+
"""
204+
return self._closed
205+
199206
async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
200207
"""
201208
:params timeout: Time in seconds to wait for next message before timing out.

nats/tests/test_client.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,179 @@ async def test_subscribe_async_generator(self):
600600

601601
await nc.close()
602602

603+
@async_test
604+
async def test_subscribe_concurrent_async_generators(self):
605+
"""Test multiple concurrent async generators on the same subscription"""
606+
nc = NATS()
607+
await nc.connect()
608+
609+
sub = await nc.subscribe("test.concurrent")
610+
611+
# Publish messages
612+
num_msgs = 12
613+
for i in range(num_msgs):
614+
await nc.publish("test.concurrent", f"msg-{i}".encode())
615+
await nc.flush()
616+
617+
# Track results from each consumer
618+
consumer_results = {}
619+
620+
async def consumer_task(consumer_id: str, max_messages: int = None):
621+
"""Consumer task that processes messages"""
622+
import random
623+
received = []
624+
try:
625+
async for msg in sub.messages:
626+
received.append(msg.data.decode())
627+
# Add random processing delay to simulate real work.
628+
await asyncio.sleep(random.uniform(0.01, 0.05))
629+
if max_messages and len(received) >= max_messages:
630+
break
631+
except Exception as e:
632+
# Store the exception for later inspection
633+
consumer_results[consumer_id] = f"Error: {e}"
634+
return
635+
consumer_results[consumer_id] = received
636+
637+
# Start multiple concurrent consumers.
638+
tasks = [
639+
asyncio.create_task(consumer_task("consumer_A", 3)),
640+
asyncio.create_task(consumer_task("consumer_B", 5)),
641+
asyncio.create_task(consumer_task("consumer_C", 4)),
642+
]
643+
644+
# Wait for all consumers to finish.
645+
await asyncio.gather(*tasks)
646+
647+
# Verify results
648+
consumer_A_msgs = consumer_results.get("consumer_A", [])
649+
consumer_B_msgs = consumer_results.get("consumer_B", [])
650+
consumer_C_msgs = consumer_results.get("consumer_C", [])
651+
652+
# Each consumer should get the expected number of messages
653+
self.assertEqual(len(consumer_A_msgs), 3)
654+
self.assertEqual(len(consumer_B_msgs), 5)
655+
self.assertEqual(len(consumer_C_msgs), 4)
656+
657+
# All messages should be unique (no duplicates across consumers)
658+
all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
659+
self.assertEqual(len(all_received), len(set(all_received)))
660+
661+
# All received messages should be from our published set
662+
expected_msgs = {f"msg-{i}" for i in range(num_msgs)}
663+
received_msgs = set(all_received)
664+
self.assertTrue(received_msgs.issubset(expected_msgs))
665+
666+
# Verify we got exactly 12 unique messages total
667+
self.assertEqual(len(received_msgs), 12)
668+
669+
await nc.close()
670+
671+
@async_test
672+
async def test_subscribe_async_generator_with_unsubscribe_limit(self):
673+
"""Test async generator respects unsubscribe max_msgs limit automatically"""
674+
nc = NATS()
675+
await nc.connect()
676+
677+
sub = await nc.subscribe("test.unsub.limit")
678+
await sub.unsubscribe(limit=5)
679+
680+
# Publish more messages than the limit
681+
num_msgs = 10
682+
for i in range(num_msgs):
683+
await nc.publish("test.unsub.limit", f"msg-{i}".encode())
684+
await nc.flush()
685+
686+
received_msgs = []
687+
async for msg in sub.messages:
688+
received_msgs.append(msg.data.decode())
689+
# Add small delay to ensure we don't race with the unsubscribe.
690+
await asyncio.sleep(0.01)
691+
692+
# Should have received exactly 5 messages due to unsubscribe limit.
693+
self.assertEqual(len(received_msgs), 5, f"Expected 5 messages, got {len(received_msgs)}: {received_msgs}")
694+
695+
# Messages should be the first 5 published.
696+
for i in range(5):
697+
self.assertIn(f"msg-{i}", received_msgs)
698+
699+
# Verify the subscription received the expected number.
700+
self.assertEqual(sub._received, 5)
701+
702+
# The generator should have stopped due to max_msgs limit being reached.
703+
self.assertEqual(sub._max_msgs, 5)
704+
705+
await nc.close()
706+
707+
@async_test
708+
async def test_subscribe_concurrent_async_generators_auto_unsubscribe(self):
709+
"""Test multiple concurrent async generators on the same subscription"""
710+
nc = NATS()
711+
await nc.connect()
712+
713+
sub = await nc.subscribe("test.concurrent")
714+
await sub.unsubscribe(5)
715+
716+
# Publish messages over the max msgs limit.
717+
num_msgs = 12
718+
for i in range(num_msgs):
719+
await nc.publish("test.concurrent", f"msg-{i}".encode())
720+
await nc.flush()
721+
722+
# Track results from each consumer
723+
consumer_results = {}
724+
725+
async def consumer_task(consumer_id: str, max_messages: int = None):
726+
"""Consumer task that processes messages"""
727+
import random
728+
received = []
729+
try:
730+
async for msg in sub.messages:
731+
received.append(msg.data.decode())
732+
# Add random processing delay to simulate real work
733+
await asyncio.sleep(random.uniform(0.01, 0.05))
734+
if max_messages and len(received) >= max_messages:
735+
break
736+
737+
# Once subscription reached max number of messages, it should unblock.
738+
except Exception as e:
739+
# Store the exception for later inspection
740+
consumer_results[consumer_id] = f"Error: {e}"
741+
return
742+
consumer_results[consumer_id] = received
743+
744+
# Start multiple concurrent consumers.
745+
tasks = [
746+
asyncio.create_task(consumer_task("consumer_A", 3)),
747+
asyncio.create_task(consumer_task("consumer_B", 5)),
748+
asyncio.create_task(consumer_task("consumer_C", 4)),
749+
]
750+
751+
# Wait for all consumers to finish.
752+
await asyncio.gather(*tasks)
753+
754+
# Verify results
755+
consumer_A_msgs = consumer_results.get("consumer_A", [])
756+
consumer_B_msgs = consumer_results.get("consumer_B", [])
757+
consumer_C_msgs = consumer_results.get("consumer_C", [])
758+
759+
# Each consumer should get the expected number of messages.
760+
total = len(consumer_A_msgs) + len(consumer_B_msgs) + len(consumer_C_msgs)
761+
self.assertEqual(total, 5)
762+
763+
# All messages should be unique (no duplicates across consumers)
764+
all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
765+
self.assertEqual(len(all_received), len(set(all_received)))
766+
767+
# All received messages should be from our published set.
768+
expected_msgs = {f"msg-{i}" for i in range(num_msgs)}
769+
received_msgs = set(all_received)
770+
self.assertTrue(received_msgs.issubset(expected_msgs))
771+
self.assertEqual(len(received_msgs), 5)
772+
773+
await nc.close()
774+
775+
603776
@async_test
604777
async def test_subscribe_async_generator_with_drain(self):
605778
"""Test async generator with drain functionality"""

0 commit comments

Comments
 (0)