88import sys
99import time
1010from dataclasses import dataclass
11-
12- from nats .client import Headers , connect
11+ from typing import Any
1312
1413
1514@dataclass
@@ -43,16 +42,22 @@ def __str__(self) -> str:
4342
4443async def run_pub_benchmark (
4544 * ,
45+ client_type : str = "client" ,
4646 url : str = "nats://localhost:4222" ,
4747 msg_count : int = 100_000 ,
4848 msg_size : int = 128 ,
4949 pub_subject : str = "test" ,
50- headers : Headers | None = None ,
50+ headers : dict [ str , str ] | Any | None = None ,
5151) -> BenchmarkResults :
5252 """Run publisher benchmark."""
5353
54- # Connect to server
55- nc = await connect (url )
54+ # Connect to server based on client type
55+ if client_type == "aio" :
56+ import nats
57+ nc = await nats .connect (url )
58+ else :
59+ from nats .client import connect
60+ nc = await connect (url )
5661
5762 try :
5863 # Prepare payload
@@ -103,12 +108,21 @@ async def run_pub_benchmark(
103108
104109
105110async def run_sub_benchmark (
106- * , url : str = "nats://localhost:4222" , msg_count : int = 100_000 , sub_subject : str = "test"
111+ * ,
112+ client_type : str = "client" ,
113+ url : str = "nats://localhost:4222" ,
114+ msg_count : int = 100_000 ,
115+ sub_subject : str = "test"
107116) -> BenchmarkResults :
108117 """Run subscriber benchmark."""
109118
110- # Connect to server
111- nc = await connect (url )
119+ # Connect to server based on client type
120+ if client_type == "aio" :
121+ import nats
122+ nc = await nats .connect (url )
123+ else :
124+ from nats .client import connect
125+ nc = await connect (url )
112126 received = 0
113127 first_msg_time = 0.0
114128 last_msg_time = 0.0
@@ -120,8 +134,13 @@ async def run_sub_benchmark(
120134 sub = await nc .subscribe (sub_subject )
121135 start_time = time .perf_counter ()
122136
123- # Receive messages
124- async for msg in sub :
137+ # Receive messages - handle different iterator styles
138+ if client_type == "aio" :
139+ iterator = sub .messages
140+ else :
141+ iterator = sub
142+
143+ async for msg in iterator :
125144 msg_time = time .perf_counter ()
126145 if received == 0 :
127146 first_msg_time = msg_time
@@ -136,6 +155,9 @@ async def run_sub_benchmark(
136155
137156 duration = last_msg_time - first_msg_time
138157
158+ # Assert we received all expected messages
159+ assert received == msg_count , f"Message loss detected! Received { received } /{ msg_count } messages"
160+
139161 # Calculate stats
140162 throughput = received / duration
141163 bytes_per_sec = total_bytes / duration
@@ -167,23 +189,26 @@ async def run_sub_benchmark(
167189
168190async def run_pubsub_benchmark (
169191 * ,
192+ client_type : str = "client" ,
170193 url : str = "nats://localhost:4222" ,
171194 msg_count : int = 100_000 ,
172195 msg_size : int = 128 ,
173196 subject : str = "test" ,
174- headers : Headers | None = None ,
197+ headers : dict [ str , str ] | Any | None = None ,
175198) -> tuple [BenchmarkResults , BenchmarkResults ]:
176199 """Run combined publisher/subscriber benchmark."""
177200
178201 # Start subscriber first
179- sub_task = asyncio .create_task (run_sub_benchmark (url = url , msg_count = msg_count , sub_subject = subject ))
202+ sub_task = asyncio .create_task (
203+ run_sub_benchmark (client_type = client_type , url = url , msg_count = msg_count , sub_subject = subject )
204+ )
180205
181206 # Small delay to ensure subscriber is ready
182207 await asyncio .sleep (0.1 )
183208
184209 # Run publisher
185210 pub_results = await run_pub_benchmark (
186- url = url , msg_count = msg_count , msg_size = msg_size , pub_subject = subject , headers = headers
211+ client_type = client_type , url = url , msg_count = msg_count , msg_size = msg_size , pub_subject = subject , headers = headers
187212 )
188213
189214 # Wait for subscriber to finish
@@ -195,6 +220,8 @@ async def run_pubsub_benchmark(
195220def main ():
196221 """Main entry point."""
197222 parser = argparse .ArgumentParser (description = "NATS benchmarking tool" )
223+ parser .add_argument ("--client" , choices = ["client" , "aio" ], default = "client" ,
224+ help = "Client type to use: 'client' (nats-client) or 'aio' (nats.aio)" )
198225 parser .add_argument ("--url" , default = "nats://localhost:4222" , help = "NATS server URL" )
199226 parser .add_argument ("--msgs" , type = int , default = 100_000 , help = "Number of messages to publish" )
200227 parser .add_argument ("--size" , type = int , default = 128 , help = "Size of the message payload" )
@@ -213,27 +240,32 @@ def main():
213240 # Create headers if requested
214241 headers = None
215242 if args .headers :
216- headers = Headers ({f"key{ i } " : f"value{ i } " for i in range (args .headers )})
243+ if args .client == "client" :
244+ from nats .client import Headers
245+ headers = Headers ({f"key{ i } " : f"value{ i } " for i in range (args .headers )})
246+ else :
247+ headers = {f"key{ i } " : f"value{ i } " for i in range (args .headers )}
217248
218249 async def run ():
250+ client_name = "nats-client" if args .client == "client" else "nats.aio"
219251 if args .pub and args .sub :
220- sys .stdout .write (f"\n Starting pub/sub benchmark [msgs={ args .msgs :,} , size={ args .size :,} B]\n " )
252+ sys .stdout .write (f"\n Starting pub/sub benchmark with { client_name } [msgs={ args .msgs :,} , size={ args .size :,} B]\n " )
221253 pub_results , sub_results = await run_pubsub_benchmark (
222- url = args .url , msg_count = args .msgs , msg_size = args .size , subject = args .subject , headers = headers
254+ client_type = args . client , url = args .url , msg_count = args .msgs , msg_size = args .size , subject = args .subject , headers = headers
223255 )
224256 sys .stdout .write (f"\n Publisher results: { pub_results } \n " )
225257 sys .stdout .write (f"\n Subscriber results: { sub_results } \n " )
226258
227259 elif args .pub :
228- sys .stdout .write (f"\n Starting publisher benchmark [msgs={ args .msgs :,} , size={ args .size :,} B]\n " )
260+ sys .stdout .write (f"\n Starting publisher benchmark with { client_name } [msgs={ args .msgs :,} , size={ args .size :,} B]\n " )
229261 results = await run_pub_benchmark (
230- url = args .url , msg_count = args .msgs , msg_size = args .size , pub_subject = args .subject , headers = headers
262+ client_type = args . client , url = args .url , msg_count = args .msgs , msg_size = args .size , pub_subject = args .subject , headers = headers
231263 )
232264 sys .stdout .write (f"\n Results: { results } \n " )
233265
234266 elif args .sub :
235- sys .stdout .write (f"\n Starting subscriber benchmark [msgs={ args .msgs :,} ]\n " )
236- results = await run_sub_benchmark (url = args .url , msg_count = args .msgs , sub_subject = args .subject )
267+ sys .stdout .write (f"\n Starting subscriber benchmark with { client_name } [msgs={ args .msgs :,} ]\n " )
268+ results = await run_sub_benchmark (client_type = args . client , url = args .url , msg_count = args .msgs , sub_subject = args .subject )
237269 sys .stdout .write (f"\n Results: { results } \n " )
238270
239271 asyncio .run (run ())
0 commit comments