Skip to content

Commit 2e4432e

Browse files
committed
Improve object store get usage
Signed-off-by: Waldemar Quevedo <[email protected]>
1 parent b735cd0 commit 2e4432e

File tree

5 files changed

+32
-26
lines changed

5 files changed

+32
-26
lines changed

nats/benchmark/sub_perf_messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,4 @@ async def main():
8080

8181

8282
if __name__ == "__main__":
83-
asyncio.run(main())
83+
asyncio.run(main())

nats/src/nats/aio/subscription.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
8080
self._cb = cb
8181
self._future = future
8282
self._closed = False
83+
self._active_generators = 0 # Track active async generators
8384

8485
# Per subscription message processor.
8586
self._pending_msgs_limit = pending_msgs_limit
@@ -138,14 +139,19 @@ async def _message_generator(self) -> AsyncIterator[Msg]:
138139
Async generator that yields messages directly from the subscription queue.
139140
"""
140141
yielded_count = 0
142+
self._active_generators += 1
141143
try:
142144
while True:
143145
# Check if subscription was cancelled/closed.
144146
if self._closed:
145147
break
146148

147149
# Check if wrapper was cancelled (for compatibility with tests).
148-
if hasattr(self, '_message_iterator') and self._message_iterator and self._message_iterator._unsubscribed_future.done():
150+
if (
151+
hasattr(self, "_message_iterator")
152+
and self._message_iterator
153+
and self._message_iterator._unsubscribed_future.done()
154+
):
149155
break
150156

151157
# Check max message limit based on how many we've yielded so far.
@@ -171,11 +177,13 @@ async def _message_generator(self) -> AsyncIterator[Msg]:
171177
# Check if we should auto-unsubscribe after yielding this message.
172178
if self._max_msgs > 0 and yielded_count >= self._max_msgs:
173179
# Cancel the wrapper too for consistency.
174-
if hasattr(self, '_message_iterator') and self._message_iterator:
180+
if hasattr(self, "_message_iterator") and self._message_iterator:
175181
self._message_iterator._cancel()
176182
break
177183
except asyncio.CancelledError:
178184
pass
185+
finally:
186+
self._active_generators -= 1
179187

180188
@property
181189
def pending_msgs(self) -> int:
@@ -334,15 +342,13 @@ def _stop_processing(self) -> None:
334342
"""
335343
if self._wait_for_msgs_task and not self._wait_for_msgs_task.done():
336344
self._wait_for_msgs_task.cancel()
337-
if hasattr(self, '_message_iterator') and self._message_iterator:
345+
if hasattr(self, "_message_iterator") and self._message_iterator:
338346
self._message_iterator._cancel()
339-
340-
# Put a sentinel in the queue to wake up any generators waiting on get()
341-
# This ensures they see the _closed flag and exit cleanly
347+
348+
# Only put sentinel if there are active async generators
342349
try:
343-
if self._pending_queue:
344-
# Put a None sentinel to wake up the generator
345-
# The generator will check _closed and exit
350+
if self._pending_queue and self._active_generators > 0:
351+
# Put a None sentinel to wake up any async generators
346352
self._pending_queue.put_nowait(None)
347353
except Exception:
348354
# Queue might be closed or full, that's ok
@@ -358,12 +364,12 @@ async def _wait_for_msgs(self, error_cb) -> None:
358364
while True:
359365
try:
360366
msg = await self._pending_queue.get()
361-
367+
362368
# Check for sentinel value (None) which signals task to stop
363369
if msg is None:
364370
self._pending_queue.task_done()
365371
break
366-
372+
367373
self._pending_size -= len(msg.data)
368374

369375
try:
@@ -394,12 +400,11 @@ class _CompatibilityIteratorWrapper:
394400
Compatibility wrapper that provides the same interface as the old _SubscriptionMessageIterator
395401
but uses the more efficient generator internally.
396402
"""
403+
397404
def __init__(self, sub: Subscription) -> None:
398405
self._sub = sub
399406
self._unsubscribed_future: asyncio.Future[bool] = asyncio.Future()
400407

401408
def _cancel(self) -> None:
402409
if not self._unsubscribed_future.done():
403410
self._unsubscribed_future.set_result(True)
404-
405-

nats/src/nats/js/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,7 @@ def __init__(
883883
self._cb = sub._cb
884884
self._future = sub._future
885885
self._closed = sub._closed
886+
self._active_generators = sub._active_generators
886887

887888
# Per subscription message processor.
888889
self._pending_msgs_limit = sub._pending_msgs_limit

nats/src/nats/js/object_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ async def get(
213213
else:
214214
executor_fn = writeinto.write
215215

216-
async for msg in sub._message_iterator:
216+
async for msg in sub.messages:
217217
tokens = msg._get_metadata_fields(msg.reply)
218218

219219
if executor:

nats/tests/test_client.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ async def iterator_func(sub):
560560
# Wait for iterator to complete after drain
561561
await asyncio.wait_for(fut, 1)
562562
await iterator_task # Ensure task cleanup
563-
563+
564564
self.assertEqual(5, len(msgs))
565565
self.assertEqual("tests.1", msgs[1].subject)
566566
self.assertEqual("tests.3", msgs[3].subject)
@@ -575,55 +575,55 @@ async def test_subscribe_async_generator(self):
575575
"""Test the optimized async generator implementation for sub.messages"""
576576
nc = NATS()
577577
await nc.connect()
578-
578+
579579
# Test basic async generator functionality
580580
sub = await nc.subscribe("test.generator")
581-
581+
582582
# Publish messages
583583
num_msgs = 10
584584
for i in range(num_msgs):
585585
await nc.publish("test.generator", f"msg-{i}".encode())
586586
await nc.flush()
587-
587+
588588
# Consume messages using async generator
589589
received_msgs = []
590590
async for msg in sub.messages:
591591
received_msgs.append(msg)
592592
if len(received_msgs) >= num_msgs:
593593
break
594-
594+
595595
# Verify all messages received correctly
596596
self.assertEqual(len(received_msgs), num_msgs)
597597
for i, msg in enumerate(received_msgs):
598598
self.assertEqual(msg.data, f"msg-{i}".encode())
599599
self.assertEqual(msg.subject, "test.generator")
600-
600+
601601
await nc.close()
602602

603603
@async_test
604604
async def test_subscribe_async_generator_with_drain(self):
605605
"""Test async generator with drain functionality"""
606606
nc = NATS()
607607
await nc.connect()
608-
608+
609609
sub = await nc.subscribe("test.drain")
610-
610+
611611
# Publish messages
612612
for i in range(5):
613613
await nc.publish("test.drain", f"drain-msg-{i}".encode())
614-
614+
615615
# Start consuming messages
616616
received_msgs = []
617617
async for msg in sub.messages:
618618
received_msgs.append(msg)
619619
# Drain after receiving all messages
620620
if len(received_msgs) == 5:
621621
await sub.drain()
622-
622+
623623
# Verify correct number of messages and drain worked
624624
self.assertEqual(len(received_msgs), 5)
625625
self.assertEqual(sub.pending_bytes, 0)
626-
626+
627627
await nc.close()
628628

629629
@async_test

0 commit comments

Comments
 (0)