88import sys
99import time
1010from dataclasses import dataclass
11-
12- from nats .client import Headers , connect
11+ from typing import Any
1312
1413
1514@dataclass
@@ -43,16 +42,24 @@ def __str__(self) -> str:
4342
4443async def run_pub_benchmark (
4544 * ,
45+ client_type : str = "client" ,
4646 url : str = "nats://localhost:4222" ,
4747 msg_count : int = 100_000 ,
4848 msg_size : int = 128 ,
4949 pub_subject : str = "test" ,
50- headers : Headers | None = None ,
50+ headers : dict [ str , str ] | Any | None = None ,
5151) -> BenchmarkResults :
5252 """Run publisher benchmark."""
5353
54- # Connect to server
55- nc = await connect (url )
54+ # Connect to server based on client type
55+ if client_type == "aio" :
56+ import nats
57+
58+ nc = await nats .connect (url )
59+ else :
60+ from nats .client import connect
61+
62+ nc = await connect (url )
5663
5764 try :
5865 # Prepare payload
@@ -103,12 +110,23 @@ async def run_pub_benchmark(
103110
104111
105112async def run_sub_benchmark (
106- * , url : str = "nats://localhost:4222" , msg_count : int = 100_000 , sub_subject : str = "test"
113+ * ,
114+ client_type : str = "client" ,
115+ url : str = "nats://localhost:4222" ,
116+ msg_count : int = 100_000 ,
117+ sub_subject : str = "test" ,
107118) -> BenchmarkResults :
108119 """Run subscriber benchmark."""
109120
110- # Connect to server
111- nc = await connect (url )
121+ # Connect to server based on client type
122+ if client_type == "aio" :
123+ import nats
124+
125+ nc = await nats .connect (url )
126+ else :
127+ from nats .client import connect
128+
129+ nc = await connect (url )
112130 received = 0
113131 first_msg_time = 0.0
114132 last_msg_time = 0.0
@@ -120,8 +138,13 @@ async def run_sub_benchmark(
120138 sub = await nc .subscribe (sub_subject )
121139 start_time = time .perf_counter ()
122140
123- # Receive messages
124- async for msg in sub :
141+ # Receive messages - handle different iterator styles
142+ if client_type == "aio" :
143+ iterator = sub .messages
144+ else :
145+ iterator = sub
146+
147+ async for msg in iterator :
125148 msg_time = time .perf_counter ()
126149 if received == 0 :
127150 first_msg_time = msg_time
@@ -136,6 +159,9 @@ async def run_sub_benchmark(
136159
137160 duration = last_msg_time - first_msg_time
138161
162+ # Assert we received all expected messages
163+ assert received == msg_count , f"Message loss detected! Received { received } /{ msg_count } messages"
164+
139165 # Calculate stats
140166 throughput = received / duration
141167 bytes_per_sec = total_bytes / duration
@@ -167,23 +193,26 @@ async def run_sub_benchmark(
167193
168194async def run_pubsub_benchmark (
169195 * ,
196+ client_type : str = "client" ,
170197 url : str = "nats://localhost:4222" ,
171198 msg_count : int = 100_000 ,
172199 msg_size : int = 128 ,
173200 subject : str = "test" ,
174- headers : Headers | None = None ,
201+ headers : dict [ str , str ] | Any | None = None ,
175202) -> tuple [BenchmarkResults , BenchmarkResults ]:
176203 """Run combined publisher/subscriber benchmark."""
177204
178205 # Start subscriber first
179- sub_task = asyncio .create_task (run_sub_benchmark (url = url , msg_count = msg_count , sub_subject = subject ))
206+ sub_task = asyncio .create_task (
207+ run_sub_benchmark (client_type = client_type , url = url , msg_count = msg_count , sub_subject = subject )
208+ )
180209
181210 # Small delay to ensure subscriber is ready
182211 await asyncio .sleep (0.1 )
183212
184213 # Run publisher
185214 pub_results = await run_pub_benchmark (
186- url = url , msg_count = msg_count , msg_size = msg_size , pub_subject = subject , headers = headers
215+ client_type = client_type , url = url , msg_count = msg_count , msg_size = msg_size , pub_subject = subject , headers = headers
187216 )
188217
189218 # Wait for subscriber to finish
@@ -195,6 +224,12 @@ async def run_pubsub_benchmark(
195224def main ():
196225 """Main entry point."""
197226 parser = argparse .ArgumentParser (description = "NATS benchmarking tool" )
227+ parser .add_argument (
228+ "--client" ,
229+ choices = ["client" , "aio" ],
230+ default = "client" ,
231+ help = "Client type to use: 'client' (nats-client) or 'aio' (nats.aio)" ,
232+ )
198233 parser .add_argument ("--url" , default = "nats://localhost:4222" , help = "NATS server URL" )
199234 parser .add_argument ("--msgs" , type = int , default = 100_000 , help = "Number of messages to publish" )
200235 parser .add_argument ("--size" , type = int , default = 128 , help = "Size of the message payload" )
@@ -213,27 +248,49 @@ def main():
213248 # Create headers if requested
214249 headers = None
215250 if args .headers :
216- headers = Headers ({f"key{ i } " : f"value{ i } " for i in range (args .headers )})
251+ if args .client == "client" :
252+ from nats .client import Headers
253+
254+ headers = Headers ({f"key{ i } " : f"value{ i } " for i in range (args .headers )})
255+ else :
256+ headers = {f"key{ i } " : f"value{ i } " for i in range (args .headers )}
217257
218258 async def run ():
259+ client_name = "nats-client" if args .client == "client" else "nats.aio"
219260 if args .pub and args .sub :
220- sys .stdout .write (f"\n Starting pub/sub benchmark [msgs={ args .msgs :,} , size={ args .size :,} B]\n " )
261+ sys .stdout .write (
262+ f"\n Starting pub/sub benchmark with { client_name } [msgs={ args .msgs :,} , size={ args .size :,} B]\n "
263+ )
221264 pub_results , sub_results = await run_pubsub_benchmark (
222- url = args .url , msg_count = args .msgs , msg_size = args .size , subject = args .subject , headers = headers
265+ client_type = args .client ,
266+ url = args .url ,
267+ msg_count = args .msgs ,
268+ msg_size = args .size ,
269+ subject = args .subject ,
270+ headers = headers ,
223271 )
224272 sys .stdout .write (f"\n Publisher results: { pub_results } \n " )
225273 sys .stdout .write (f"\n Subscriber results: { sub_results } \n " )
226274
227275 elif args .pub :
228- sys .stdout .write (f"\n Starting publisher benchmark [msgs={ args .msgs :,} , size={ args .size :,} B]\n " )
276+ sys .stdout .write (
277+ f"\n Starting publisher benchmark with { client_name } [msgs={ args .msgs :,} , size={ args .size :,} B]\n "
278+ )
229279 results = await run_pub_benchmark (
230- url = args .url , msg_count = args .msgs , msg_size = args .size , pub_subject = args .subject , headers = headers
280+ client_type = args .client ,
281+ url = args .url ,
282+ msg_count = args .msgs ,
283+ msg_size = args .size ,
284+ pub_subject = args .subject ,
285+ headers = headers ,
231286 )
232287 sys .stdout .write (f"\n Results: { results } \n " )
233288
234289 elif args .sub :
235- sys .stdout .write (f"\n Starting subscriber benchmark [msgs={ args .msgs :,} ]\n " )
236- results = await run_sub_benchmark (url = args .url , msg_count = args .msgs , sub_subject = args .subject )
290+ sys .stdout .write (f"\n Starting subscriber benchmark with { client_name } [msgs={ args .msgs :,} ]\n " )
291+ results = await run_sub_benchmark (
292+ client_type = args .client , url = args .url , msg_count = args .msgs , sub_subject = args .subject
293+ )
237294 sys .stdout .write (f"\n Results: { results } \n " )
238295
239296 asyncio .run (run ())
0 commit comments