Skip to content

Commit 777a0e0

Browse files
committed
Add --client option to bench script
1 parent 6b3edd9 commit 777a0e0

File tree

1 file changed

+77
-20
lines changed

1 file changed

+77
-20
lines changed

nats-client/tools/bench.py

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import sys
99
import time
1010
from dataclasses import dataclass
11-
12-
from nats.client import Headers, connect
11+
from typing import Any
1312

1413

1514
@dataclass
@@ -43,16 +42,24 @@ def __str__(self) -> str:
4342

4443
async def run_pub_benchmark(
4544
*,
45+
client_type: str = "client",
4646
url: str = "nats://localhost:4222",
4747
msg_count: int = 100_000,
4848
msg_size: int = 128,
4949
pub_subject: str = "test",
50-
headers: Headers | None = None,
50+
headers: dict[str, str] | Any | None = None,
5151
) -> BenchmarkResults:
5252
"""Run publisher benchmark."""
5353

54-
# Connect to server
55-
nc = await connect(url)
54+
# Connect to server based on client type
55+
if client_type == "aio":
56+
import nats
57+
58+
nc = await nats.connect(url)
59+
else:
60+
from nats.client import connect
61+
62+
nc = await connect(url)
5663

5764
try:
5865
# Prepare payload
@@ -103,12 +110,23 @@ async def run_pub_benchmark(
103110

104111

105112
async def run_sub_benchmark(
106-
*, url: str = "nats://localhost:4222", msg_count: int = 100_000, sub_subject: str = "test"
113+
*,
114+
client_type: str = "client",
115+
url: str = "nats://localhost:4222",
116+
msg_count: int = 100_000,
117+
sub_subject: str = "test",
107118
) -> BenchmarkResults:
108119
"""Run subscriber benchmark."""
109120

110-
# Connect to server
111-
nc = await connect(url)
121+
# Connect to server based on client type
122+
if client_type == "aio":
123+
import nats
124+
125+
nc = await nats.connect(url)
126+
else:
127+
from nats.client import connect
128+
129+
nc = await connect(url)
112130
received = 0
113131
first_msg_time = 0.0
114132
last_msg_time = 0.0
@@ -120,8 +138,13 @@ async def run_sub_benchmark(
120138
sub = await nc.subscribe(sub_subject)
121139
start_time = time.perf_counter()
122140

123-
# Receive messages
124-
async for msg in sub:
141+
# Receive messages - handle different iterator styles
142+
if client_type == "aio":
143+
iterator = sub.messages
144+
else:
145+
iterator = sub
146+
147+
async for msg in iterator:
125148
msg_time = time.perf_counter()
126149
if received == 0:
127150
first_msg_time = msg_time
@@ -136,6 +159,9 @@ async def run_sub_benchmark(
136159

137160
duration = last_msg_time - first_msg_time
138161

162+
# Assert we received all expected messages
163+
assert received == msg_count, f"Message loss detected! Received {received}/{msg_count} messages"
164+
139165
# Calculate stats
140166
throughput = received / duration
141167
bytes_per_sec = total_bytes / duration
@@ -167,23 +193,26 @@ async def run_sub_benchmark(
167193

168194
async def run_pubsub_benchmark(
169195
*,
196+
client_type: str = "client",
170197
url: str = "nats://localhost:4222",
171198
msg_count: int = 100_000,
172199
msg_size: int = 128,
173200
subject: str = "test",
174-
headers: Headers | None = None,
201+
headers: dict[str, str] | Any | None = None,
175202
) -> tuple[BenchmarkResults, BenchmarkResults]:
176203
"""Run combined publisher/subscriber benchmark."""
177204

178205
# Start subscriber first
179-
sub_task = asyncio.create_task(run_sub_benchmark(url=url, msg_count=msg_count, sub_subject=subject))
206+
sub_task = asyncio.create_task(
207+
run_sub_benchmark(client_type=client_type, url=url, msg_count=msg_count, sub_subject=subject)
208+
)
180209

181210
# Small delay to ensure subscriber is ready
182211
await asyncio.sleep(0.1)
183212

184213
# Run publisher
185214
pub_results = await run_pub_benchmark(
186-
url=url, msg_count=msg_count, msg_size=msg_size, pub_subject=subject, headers=headers
215+
client_type=client_type, url=url, msg_count=msg_count, msg_size=msg_size, pub_subject=subject, headers=headers
187216
)
188217

189218
# Wait for subscriber to finish
@@ -195,6 +224,12 @@ async def run_pubsub_benchmark(
195224
def main():
196225
"""Main entry point."""
197226
parser = argparse.ArgumentParser(description="NATS benchmarking tool")
227+
parser.add_argument(
228+
"--client",
229+
choices=["client", "aio"],
230+
default="client",
231+
help="Client type to use: 'client' (nats-client) or 'aio' (nats.aio)",
232+
)
198233
parser.add_argument("--url", default="nats://localhost:4222", help="NATS server URL")
199234
parser.add_argument("--msgs", type=int, default=100_000, help="Number of messages to publish")
200235
parser.add_argument("--size", type=int, default=128, help="Size of the message payload")
@@ -213,27 +248,49 @@ def main():
213248
# Create headers if requested
214249
headers = None
215250
if args.headers:
216-
headers = Headers({f"key{i}": f"value{i}" for i in range(args.headers)})
251+
if args.client == "client":
252+
from nats.client import Headers
253+
254+
headers = Headers({f"key{i}": f"value{i}" for i in range(args.headers)})
255+
else:
256+
headers = {f"key{i}": f"value{i}" for i in range(args.headers)}
217257

218258
async def run():
259+
client_name = "nats-client" if args.client == "client" else "nats.aio"
219260
if args.pub and args.sub:
220-
sys.stdout.write(f"\nStarting pub/sub benchmark [msgs={args.msgs:,}, size={args.size:,} B]\n")
261+
sys.stdout.write(
262+
f"\nStarting pub/sub benchmark with {client_name} [msgs={args.msgs:,}, size={args.size:,} B]\n"
263+
)
221264
pub_results, sub_results = await run_pubsub_benchmark(
222-
url=args.url, msg_count=args.msgs, msg_size=args.size, subject=args.subject, headers=headers
265+
client_type=args.client,
266+
url=args.url,
267+
msg_count=args.msgs,
268+
msg_size=args.size,
269+
subject=args.subject,
270+
headers=headers,
223271
)
224272
sys.stdout.write(f"\nPublisher results: {pub_results}\n")
225273
sys.stdout.write(f"\nSubscriber results: {sub_results}\n")
226274

227275
elif args.pub:
228-
sys.stdout.write(f"\nStarting publisher benchmark [msgs={args.msgs:,}, size={args.size:,} B]\n")
276+
sys.stdout.write(
277+
f"\nStarting publisher benchmark with {client_name} [msgs={args.msgs:,}, size={args.size:,} B]\n"
278+
)
229279
results = await run_pub_benchmark(
230-
url=args.url, msg_count=args.msgs, msg_size=args.size, pub_subject=args.subject, headers=headers
280+
client_type=args.client,
281+
url=args.url,
282+
msg_count=args.msgs,
283+
msg_size=args.size,
284+
pub_subject=args.subject,
285+
headers=headers,
231286
)
232287
sys.stdout.write(f"\nResults: {results}\n")
233288

234289
elif args.sub:
235-
sys.stdout.write(f"\nStarting subscriber benchmark [msgs={args.msgs:,}]\n")
236-
results = await run_sub_benchmark(url=args.url, msg_count=args.msgs, sub_subject=args.subject)
290+
sys.stdout.write(f"\nStarting subscriber benchmark with {client_name} [msgs={args.msgs:,}]\n")
291+
results = await run_sub_benchmark(
292+
client_type=args.client, url=args.url, msg_count=args.msgs, sub_subject=args.subject
293+
)
237294
sys.stdout.write(f"\nResults: {results}\n")
238295

239296
asyncio.run(run())

0 commit comments

Comments
 (0)