@@ -600,6 +600,179 @@ async def test_subscribe_async_generator(self):
600600
601601 await nc .close ()
602602
603+ @async_test
604+ async def test_subscribe_concurrent_async_generators (self ):
605+ """Test multiple concurrent async generators on the same subscription"""
606+ nc = NATS ()
607+ await nc .connect ()
608+
609+ sub = await nc .subscribe ("test.concurrent" )
610+
611+ # Publish messages
612+ num_msgs = 12
613+ for i in range (num_msgs ):
614+ await nc .publish ("test.concurrent" , f"msg-{ i } " .encode ())
615+ await nc .flush ()
616+
617+ # Track results from each consumer
618+ consumer_results = {}
619+
620+ async def consumer_task (consumer_id : str , max_messages : int = None ):
621+ """Consumer task that processes messages"""
622+ import random
623+ received = []
624+ try :
625+ async for msg in sub .messages :
626+ received .append (msg .data .decode ())
627+ # Add random processing delay to simulate real work.
628+ await asyncio .sleep (random .uniform (0.01 , 0.05 ))
629+ if max_messages and len (received ) >= max_messages :
630+ break
631+ except Exception as e :
632+ # Store the exception for later inspection
633+ consumer_results [consumer_id ] = f"Error: { e } "
634+ return
635+ consumer_results [consumer_id ] = received
636+
637+ # Start multiple concurrent consumers.
638+ tasks = [
639+ asyncio .create_task (consumer_task ("consumer_A" , 3 )),
640+ asyncio .create_task (consumer_task ("consumer_B" , 5 )),
641+ asyncio .create_task (consumer_task ("consumer_C" , 4 )),
642+ ]
643+
644+ # Wait for all consumers to finish.
645+ await asyncio .gather (* tasks )
646+
647+ # Verify results
648+ consumer_A_msgs = consumer_results .get ("consumer_A" , [])
649+ consumer_B_msgs = consumer_results .get ("consumer_B" , [])
650+ consumer_C_msgs = consumer_results .get ("consumer_C" , [])
651+
652+ # Each consumer should get the expected number of messages
653+ self .assertEqual (len (consumer_A_msgs ), 3 )
654+ self .assertEqual (len (consumer_B_msgs ), 5 )
655+ self .assertEqual (len (consumer_C_msgs ), 4 )
656+
657+ # All messages should be unique (no duplicates across consumers)
658+ all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
659+ self .assertEqual (len (all_received ), len (set (all_received )))
660+
661+ # All received messages should be from our published set
662+ expected_msgs = {f"msg-{ i } " for i in range (num_msgs )}
663+ received_msgs = set (all_received )
664+ self .assertTrue (received_msgs .issubset (expected_msgs ))
665+
666+ # Verify we got exactly 12 unique messages total
667+ self .assertEqual (len (received_msgs ), 12 )
668+
669+ await nc .close ()
670+
671+ @async_test
672+ async def test_subscribe_async_generator_with_unsubscribe_limit (self ):
673+ """Test async generator respects unsubscribe max_msgs limit automatically"""
674+ nc = NATS ()
675+ await nc .connect ()
676+
677+ sub = await nc .subscribe ("test.unsub.limit" )
678+ await sub .unsubscribe (limit = 5 )
679+
680+ # Publish more messages than the limit
681+ num_msgs = 10
682+ for i in range (num_msgs ):
683+ await nc .publish ("test.unsub.limit" , f"msg-{ i } " .encode ())
684+ await nc .flush ()
685+
686+ received_msgs = []
687+ async for msg in sub .messages :
688+ received_msgs .append (msg .data .decode ())
689+ # Add small delay to ensure we don't race with the unsubscribe.
690+ await asyncio .sleep (0.01 )
691+
692+ # Should have received exactly 5 messages due to unsubscribe limit.
693+ self .assertEqual (len (received_msgs ), 5 , f"Expected 5 messages, got { len (received_msgs )} : { received_msgs } " )
694+
695+ # Messages should be the first 5 published.
696+ for i in range (5 ):
697+ self .assertIn (f"msg-{ i } " , received_msgs )
698+
699+ # Verify the subscription received the expected number.
700+ self .assertEqual (sub ._received , 5 )
701+
702+ # The generator should have stopped due to max_msgs limit being reached.
703+ self .assertEqual (sub ._max_msgs , 5 )
704+
705+ await nc .close ()
706+
707+ @async_test
708+ async def test_subscribe_concurrent_async_generators_auto_unsubscribe (self ):
709+ """Test multiple concurrent async generators on the same subscription"""
710+ nc = NATS ()
711+ await nc .connect ()
712+
713+ sub = await nc .subscribe ("test.concurrent" )
714+ await sub .unsubscribe (5 )
715+
716+ # Publish messages over the max msgs limit.
717+ num_msgs = 12
718+ for i in range (num_msgs ):
719+ await nc .publish ("test.concurrent" , f"msg-{ i } " .encode ())
720+ await nc .flush ()
721+
722+ # Track results from each consumer
723+ consumer_results = {}
724+
725+ async def consumer_task (consumer_id : str , max_messages : int = None ):
726+ """Consumer task that processes messages"""
727+ import random
728+ received = []
729+ try :
730+ async for msg in sub .messages :
731+ received .append (msg .data .decode ())
732+ # Add random processing delay to simulate real work
733+ await asyncio .sleep (random .uniform (0.01 , 0.05 ))
734+ if max_messages and len (received ) >= max_messages :
735+ break
736+
737+ # Once subscription reached max number of messages, it should unblock.
738+ except Exception as e :
739+ # Store the exception for later inspection
740+ consumer_results [consumer_id ] = f"Error: { e } "
741+ return
742+ consumer_results [consumer_id ] = received
743+
744+ # Start multiple concurrent consumers.
745+ tasks = [
746+ asyncio .create_task (consumer_task ("consumer_A" , 3 )),
747+ asyncio .create_task (consumer_task ("consumer_B" , 5 )),
748+ asyncio .create_task (consumer_task ("consumer_C" , 4 )),
749+ ]
750+
751+ # Wait for all consumers to finish.
752+ await asyncio .gather (* tasks )
753+
754+ # Verify results
755+ consumer_A_msgs = consumer_results .get ("consumer_A" , [])
756+ consumer_B_msgs = consumer_results .get ("consumer_B" , [])
757+ consumer_C_msgs = consumer_results .get ("consumer_C" , [])
758+
759+ # Each consumer should get the expected number of messages.
760+ total = len (consumer_A_msgs ) + len (consumer_B_msgs ) + len (consumer_C_msgs )
761+ self .assertEqual (total , 5 )
762+
763+ # All messages should be unique (no duplicates across consumers)
764+ all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
765+ self .assertEqual (len (all_received ), len (set (all_received )))
766+
767+ # All received messages should be from our published set.
768+ expected_msgs = {f"msg-{ i } " for i in range (num_msgs )}
769+ received_msgs = set (all_received )
770+ self .assertTrue (received_msgs .issubset (expected_msgs ))
771+ self .assertEqual (len (received_msgs ), 5 )
772+
773+ await nc .close ()
774+
775+
603776 @async_test
604777 async def test_subscribe_async_generator_with_drain (self ):
605778 """Test async generator with drain functionality"""
0 commit comments