2020 AsyncIterator ,
2121 Awaitable ,
2222 Callable ,
23- Dict ,
23+ List ,
2424 Optional ,
2525)
26+ from uuid import uuid4
2627
2728from nats import errors
2829
@@ -80,7 +81,6 @@ def __init__(
8081 self ._cb = cb
8182 self ._future = future
8283 self ._closed = False
83- self ._active_generators = 0 # Track active async generators
8484
8585 # Per subscription message processor.
8686 self ._pending_msgs_limit = pending_msgs_limit
@@ -89,12 +89,11 @@ def __init__(
8989 # If no callback, then this is a sync subscription which will
9090 # require tracking the next_msg calls inflight for cancelling.
9191 if cb is None :
92- self ._pending_next_msgs_calls : Optional [ Dict [ str , asyncio . Task ]] = {}
92+ self ._pending_next_msgs_calls = {}
9393 else :
9494 self ._pending_next_msgs_calls = None
9595 self ._pending_size = 0
9696 self ._wait_for_msgs_task = None
97- # For compatibility with tests that expect _message_iterator
9897 self ._message_iterator = None
9998
10099 # For JetStream enabled subscriptions.
@@ -130,61 +129,10 @@ def messages(self) -> AsyncIterator[Msg]:
130129 async for msg in sub.messages:
131130 print('Received', msg)
132131 """
133- if self ._cb :
132+ if not self ._message_iterator :
134133 raise errors .Error ("cannot iterate over messages with a non iteration subscription type" )
135134
136- return self ._message_generator ()
137-
138- async def _message_generator (self ) -> AsyncIterator [Msg ]:
139- """
140- Async generator that yields messages directly from the subscription queue.
141- """
142- yielded_count = 0
143- self ._active_generators += 1
144- try :
145- while True :
146- # Check if subscription was cancelled/closed.
147- if self ._closed :
148- break
149-
150- # Check if wrapper was cancelled (for compatibility with tests).
151- if (
152- hasattr (self , "_message_iterator" )
153- and self ._message_iterator
154- and self ._message_iterator ._unsubscribed_future .done ()
155- ):
156- break
157-
158- # Check max message limit based on how many we've yielded so far.
159- if self ._max_msgs > 0 and yielded_count >= self ._max_msgs :
160- break
161-
162- try :
163- msg = await self ._pending_queue .get ()
164- except asyncio .CancelledError :
165- break
166-
167- # Check for sentinel value which signals generator to stop.
168- if msg is None :
169- self ._pending_queue .task_done ()
170- break
171-
172- self ._pending_queue .task_done ()
173- self ._pending_size -= len (msg .data )
174-
175- yield msg
176- yielded_count += 1
177-
178- # Check if we should auto-unsubscribe after yielding this message.
179- if self ._max_msgs > 0 and yielded_count >= self ._max_msgs :
180- # Cancel the wrapper too for consistency.
181- if hasattr (self , "_message_iterator" ) and self ._message_iterator :
182- self ._message_iterator ._cancel ()
183- break
184- except asyncio .CancelledError :
185- pass
186- finally :
187- self ._active_generators -= 1
135+ return self ._message_iterator
188136
189137 @property
190138 def pending_msgs (self ) -> int :
@@ -212,7 +160,6 @@ def delivered(self) -> int:
212160 async def next_msg (self , timeout : Optional [float ] = 1.0 ) -> Msg :
213161 """
214162 :params timeout: Time in seconds to wait for next message before timing out.
215- Use 0 or None to wait forever (no timeout).
216163 :raises nats.errors.TimeoutError:
217164
218165 next_msg can be used to retrieve the next message from a stream of messages using
@@ -221,23 +168,22 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
221168 sub = await nc.subscribe('hello')
222169 msg = await sub.next_msg(timeout=1)
223170
224- # Wait forever for a message
225- msg = await sub.next_msg(timeout=0)
226-
227171 """
172+
173+ async def timed_get () -> Msg :
174+ return await asyncio .wait_for (self ._pending_queue .get (), timeout )
175+
228176 if self ._conn .is_closed :
229177 raise errors .ConnectionClosedError
230178
231179 if self ._cb :
232180 raise errors .Error ("nats: next_msg cannot be used in async subscriptions" )
233181
182+ task_name = str (uuid4 ())
234183 try :
235- if timeout == 0 or timeout is None :
236- # Wait forever for a message
237- msg = await self ._pending_queue .get ()
238- else :
239- # Wait with timeout
240- msg = await asyncio .wait_for (self ._pending_queue .get (), timeout )
184+ future = asyncio .create_task (timed_get ())
185+ self ._pending_next_msgs_calls [task_name ] = future
186+ msg = await future
241187 except asyncio .TimeoutError :
242188 if self ._conn .is_closed :
243189 raise errors .ConnectionClosedError
@@ -253,6 +199,8 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
253199 # regardless of whether it has been processed.
254200 self ._pending_queue .task_done ()
255201 return msg
202+ finally :
203+ self ._pending_next_msgs_calls .pop (task_name , None )
256204
257205 def _start (self , error_cb ):
258206 """
@@ -270,9 +218,7 @@ def _start(self, error_cb):
270218 # Used to handle the single response from a request.
271219 pass
272220 else :
273- # For async iteration, we now use a generator directly via the messages property
274- # But we create a compatibility wrapper for tests
275- self ._message_iterator = _CompatibilityIteratorWrapper (self )
221+ self ._message_iterator = _SubscriptionMessageIterator (self )
276222
277223 async def drain (self ):
278224 """
@@ -343,18 +289,9 @@ def _stop_processing(self) -> None:
343289 """
344290 if self ._wait_for_msgs_task and not self ._wait_for_msgs_task .done ():
345291 self ._wait_for_msgs_task .cancel ()
346- if hasattr ( self , "_message_iterator" ) and self ._message_iterator :
292+ if self ._message_iterator :
347293 self ._message_iterator ._cancel ()
348294
349- # Only put sentinel if there are active async generators
350- try :
351- if self ._pending_queue and self ._active_generators > 0 :
352- # Put a None sentinel to wake up any async generators
353- self ._pending_queue .put_nowait (None )
354- except Exception :
355- # Queue might be closed or full, that's ok
356- pass
357-
358295 async def _wait_for_msgs (self , error_cb ) -> None :
359296 """
360297 A coroutine to read and process messages if a callback is provided.
@@ -365,12 +302,6 @@ async def _wait_for_msgs(self, error_cb) -> None:
365302 while True :
366303 try :
367304 msg = await self ._pending_queue .get ()
368-
369- # Check for sentinel value (None) which signals task to stop
370- if msg is None :
371- self ._pending_queue .task_done ()
372- break
373-
374305 self ._pending_size -= len (msg .data )
375306
376307 try :
@@ -396,16 +327,35 @@ async def _wait_for_msgs(self, error_cb) -> None:
396327 break
397328
398329
399- class _CompatibilityIteratorWrapper :
400- """
401- Compatibility wrapper that provides the same interface as the old _SubscriptionMessageIterator
402- but uses the more efficient generator internally.
403- """
404-
330+ class _SubscriptionMessageIterator :
405331 def __init__ (self , sub : Subscription ) -> None :
406- self ._sub = sub
332+ self ._sub : Subscription = sub
333+ self ._queue : asyncio .Queue [Msg ] = sub ._pending_queue
407334 self ._unsubscribed_future : asyncio .Future [bool ] = asyncio .Future ()
408335
409336 def _cancel (self ) -> None :
410337 if not self ._unsubscribed_future .done ():
411338 self ._unsubscribed_future .set_result (True )
339+
340+ def __aiter__ (self ) -> _SubscriptionMessageIterator :
341+ return self
342+
343+ async def __anext__ (self ) -> Msg :
344+ get_task = asyncio .get_running_loop ().create_task (self ._queue .get ())
345+ tasks : List [asyncio .Future ] = [get_task , self ._unsubscribed_future ]
346+ finished , _ = await asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
347+ sub = self ._sub
348+
349+ if get_task in finished :
350+ self ._queue .task_done ()
351+ msg = get_task .result ()
352+ self ._sub ._pending_size -= len (msg .data )
353+
354+ # Unblock the iterator in case it has already received enough messages.
355+ if sub ._max_msgs > 0 and sub ._received >= sub ._max_msgs :
356+ self ._cancel ()
357+ return msg
358+ elif self ._unsubscribed_future .done ():
359+ get_task .cancel ()
360+
361+ raise StopAsyncIteration
0 commit comments