Skip to content

Commit a5dc995

Browse files
committed
Restore nats directory
1 parent b4b1b3c commit a5dc995

File tree

6 files changed

+95
-243
lines changed

6 files changed

+95
-243
lines changed

nats/src/nats/aio/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,8 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
754754
# Async subs use join when draining already so just cancel here.
755755
if sub._wait_for_msgs_task and not sub._wait_for_msgs_task.done():
756756
sub._wait_for_msgs_task.cancel()
757+
if sub._message_iterator:
758+
sub._message_iterator._cancel()
757759
# Sync subs may have some inflight next_msg calls that could be blocking
758760
# so cancel them here to unblock them.
759761
if sub._pending_next_msgs_calls:

nats/src/nats/aio/subscription.py

Lines changed: 43 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
AsyncIterator,
2121
Awaitable,
2222
Callable,
23-
Dict,
23+
List,
2424
Optional,
2525
)
26+
from uuid import uuid4
2627

2728
from nats import errors
2829

@@ -80,7 +81,6 @@ def __init__(
8081
self._cb = cb
8182
self._future = future
8283
self._closed = False
83-
self._active_generators = 0 # Track active async generators
8484

8585
# Per subscription message processor.
8686
self._pending_msgs_limit = pending_msgs_limit
@@ -89,12 +89,11 @@ def __init__(
8989
# If no callback, then this is a sync subscription which will
9090
# require tracking the next_msg calls inflight for cancelling.
9191
if cb is None:
92-
self._pending_next_msgs_calls: Optional[Dict[str, asyncio.Task]] = {}
92+
self._pending_next_msgs_calls = {}
9393
else:
9494
self._pending_next_msgs_calls = None
9595
self._pending_size = 0
9696
self._wait_for_msgs_task = None
97-
# For compatibility with tests that expect _message_iterator
9897
self._message_iterator = None
9998

10099
# For JetStream enabled subscriptions.
@@ -130,61 +129,10 @@ def messages(self) -> AsyncIterator[Msg]:
130129
async for msg in sub.messages:
131130
print('Received', msg)
132131
"""
133-
if self._cb:
132+
if not self._message_iterator:
134133
raise errors.Error("cannot iterate over messages with a non iteration subscription type")
135134

136-
return self._message_generator()
137-
138-
async def _message_generator(self) -> AsyncIterator[Msg]:
139-
"""
140-
Async generator that yields messages directly from the subscription queue.
141-
"""
142-
yielded_count = 0
143-
self._active_generators += 1
144-
try:
145-
while True:
146-
# Check if subscription was cancelled/closed.
147-
if self._closed:
148-
break
149-
150-
# Check if wrapper was cancelled (for compatibility with tests).
151-
if (
152-
hasattr(self, "_message_iterator")
153-
and self._message_iterator
154-
and self._message_iterator._unsubscribed_future.done()
155-
):
156-
break
157-
158-
# Check max message limit based on how many we've yielded so far.
159-
if self._max_msgs > 0 and yielded_count >= self._max_msgs:
160-
break
161-
162-
try:
163-
msg = await self._pending_queue.get()
164-
except asyncio.CancelledError:
165-
break
166-
167-
# Check for sentinel value which signals generator to stop.
168-
if msg is None:
169-
self._pending_queue.task_done()
170-
break
171-
172-
self._pending_queue.task_done()
173-
self._pending_size -= len(msg.data)
174-
175-
yield msg
176-
yielded_count += 1
177-
178-
# Check if we should auto-unsubscribe after yielding this message.
179-
if self._max_msgs > 0 and yielded_count >= self._max_msgs:
180-
# Cancel the wrapper too for consistency.
181-
if hasattr(self, "_message_iterator") and self._message_iterator:
182-
self._message_iterator._cancel()
183-
break
184-
except asyncio.CancelledError:
185-
pass
186-
finally:
187-
self._active_generators -= 1
135+
return self._message_iterator
188136

189137
@property
190138
def pending_msgs(self) -> int:
@@ -212,7 +160,6 @@ def delivered(self) -> int:
212160
async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
213161
"""
214162
:params timeout: Time in seconds to wait for next message before timing out.
215-
Use 0 or None to wait forever (no timeout).
216163
:raises nats.errors.TimeoutError:
217164
218165
next_msg can be used to retrieve the next message from a stream of messages using
@@ -221,23 +168,22 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
221168
sub = await nc.subscribe('hello')
222169
msg = await sub.next_msg(timeout=1)
223170
224-
# Wait forever for a message
225-
msg = await sub.next_msg(timeout=0)
226-
227171
"""
172+
173+
async def timed_get() -> Msg:
174+
return await asyncio.wait_for(self._pending_queue.get(), timeout)
175+
228176
if self._conn.is_closed:
229177
raise errors.ConnectionClosedError
230178

231179
if self._cb:
232180
raise errors.Error("nats: next_msg cannot be used in async subscriptions")
233181

182+
task_name = str(uuid4())
234183
try:
235-
if timeout == 0 or timeout is None:
236-
# Wait forever for a message
237-
msg = await self._pending_queue.get()
238-
else:
239-
# Wait with timeout
240-
msg = await asyncio.wait_for(self._pending_queue.get(), timeout)
184+
future = asyncio.create_task(timed_get())
185+
self._pending_next_msgs_calls[task_name] = future
186+
msg = await future
241187
except asyncio.TimeoutError:
242188
if self._conn.is_closed:
243189
raise errors.ConnectionClosedError
@@ -253,6 +199,8 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
253199
# regardless of whether it has been processed.
254200
self._pending_queue.task_done()
255201
return msg
202+
finally:
203+
self._pending_next_msgs_calls.pop(task_name, None)
256204

257205
def _start(self, error_cb):
258206
"""
@@ -270,9 +218,7 @@ def _start(self, error_cb):
270218
# Used to handle the single response from a request.
271219
pass
272220
else:
273-
# For async iteration, we now use a generator directly via the messages property
274-
# But we create a compatibility wrapper for tests
275-
self._message_iterator = _CompatibilityIteratorWrapper(self)
221+
self._message_iterator = _SubscriptionMessageIterator(self)
276222

277223
async def drain(self):
278224
"""
@@ -343,18 +289,9 @@ def _stop_processing(self) -> None:
343289
"""
344290
if self._wait_for_msgs_task and not self._wait_for_msgs_task.done():
345291
self._wait_for_msgs_task.cancel()
346-
if hasattr(self, "_message_iterator") and self._message_iterator:
292+
if self._message_iterator:
347293
self._message_iterator._cancel()
348294

349-
# Only put sentinel if there are active async generators
350-
try:
351-
if self._pending_queue and self._active_generators > 0:
352-
# Put a None sentinel to wake up any async generators
353-
self._pending_queue.put_nowait(None)
354-
except Exception:
355-
# Queue might be closed or full, that's ok
356-
pass
357-
358295
async def _wait_for_msgs(self, error_cb) -> None:
359296
"""
360297
A coroutine to read and process messages if a callback is provided.
@@ -365,12 +302,6 @@ async def _wait_for_msgs(self, error_cb) -> None:
365302
while True:
366303
try:
367304
msg = await self._pending_queue.get()
368-
369-
# Check for sentinel value (None) which signals task to stop
370-
if msg is None:
371-
self._pending_queue.task_done()
372-
break
373-
374305
self._pending_size -= len(msg.data)
375306

376307
try:
@@ -396,16 +327,35 @@ async def _wait_for_msgs(self, error_cb) -> None:
396327
break
397328

398329

399-
class _CompatibilityIteratorWrapper:
400-
"""
401-
Compatibility wrapper that provides the same interface as the old _SubscriptionMessageIterator
402-
but uses the more efficient generator internally.
403-
"""
404-
330+
class _SubscriptionMessageIterator:
405331
def __init__(self, sub: Subscription) -> None:
406-
self._sub = sub
332+
self._sub: Subscription = sub
333+
self._queue: asyncio.Queue[Msg] = sub._pending_queue
407334
self._unsubscribed_future: asyncio.Future[bool] = asyncio.Future()
408335

409336
def _cancel(self) -> None:
410337
if not self._unsubscribed_future.done():
411338
self._unsubscribed_future.set_result(True)
339+
340+
def __aiter__(self) -> _SubscriptionMessageIterator:
341+
return self
342+
343+
async def __anext__(self) -> Msg:
344+
get_task = asyncio.get_running_loop().create_task(self._queue.get())
345+
tasks: List[asyncio.Future] = [get_task, self._unsubscribed_future]
346+
finished, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
347+
sub = self._sub
348+
349+
if get_task in finished:
350+
self._queue.task_done()
351+
msg = get_task.result()
352+
self._sub._pending_size -= len(msg.data)
353+
354+
# Unblock the iterator in case it has already received enough messages.
355+
if sub._max_msgs > 0 and sub._received >= sub._max_msgs:
356+
self._cancel()
357+
return msg
358+
elif self._unsubscribed_future.done():
359+
get_task.cancel()
360+
361+
raise StopAsyncIteration

nats/src/nats/js/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,14 +883,14 @@ def __init__(
883883
self._cb = sub._cb
884884
self._future = sub._future
885885
self._closed = sub._closed
886-
self._active_generators = sub._active_generators
887886

888887
# Per subscription message processor.
889888
self._pending_msgs_limit = sub._pending_msgs_limit
890889
self._pending_bytes_limit = sub._pending_bytes_limit
891890
self._pending_queue = sub._pending_queue
892891
self._pending_size = sub._pending_size
893892
self._wait_for_msgs_task = sub._wait_for_msgs_task
893+
self._message_iterator = sub._message_iterator
894894
self._pending_next_msgs_calls = sub._pending_next_msgs_calls
895895

896896
async def consumer_info(self) -> api.ConsumerInfo:

nats/src/nats/js/object_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ async def get(
213213
else:
214214
executor_fn = writeinto.write
215215

216-
async for msg in sub.messages:
216+
async for msg in sub._message_iterator:
217217
tokens = msg._get_metadata_fields(msg.reply)
218218

219219
if executor:

0 commit comments

Comments
 (0)