Skip to content

Commit bc5924c

Browse files
committed
Add --client option to bench script
1 parent 6b3edd9 commit bc5924c

File tree

1 file changed

+52
-20
lines changed

1 file changed

+52
-20
lines changed

nats-client/tools/bench.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import sys
99
import time
1010
from dataclasses import dataclass
11-
12-
from nats.client import Headers, connect
11+
from typing import Any
1312

1413

1514
@dataclass
@@ -43,16 +42,22 @@ def __str__(self) -> str:
4342

4443
async def run_pub_benchmark(
4544
*,
45+
client_type: str = "client",
4646
url: str = "nats://localhost:4222",
4747
msg_count: int = 100_000,
4848
msg_size: int = 128,
4949
pub_subject: str = "test",
50-
headers: Headers | None = None,
50+
headers: dict[str, str] | Any | None = None,
5151
) -> BenchmarkResults:
5252
"""Run publisher benchmark."""
5353

54-
# Connect to server
55-
nc = await connect(url)
54+
# Connect to server based on client type
55+
if client_type == "aio":
56+
import nats
57+
nc = await nats.connect(url)
58+
else:
59+
from nats.client import connect
60+
nc = await connect(url)
5661

5762
try:
5863
# Prepare payload
@@ -103,12 +108,21 @@ async def run_pub_benchmark(
103108

104109

105110
async def run_sub_benchmark(
106-
*, url: str = "nats://localhost:4222", msg_count: int = 100_000, sub_subject: str = "test"
111+
*,
112+
client_type: str = "client",
113+
url: str = "nats://localhost:4222",
114+
msg_count: int = 100_000,
115+
sub_subject: str = "test"
107116
) -> BenchmarkResults:
108117
"""Run subscriber benchmark."""
109118

110-
# Connect to server
111-
nc = await connect(url)
119+
# Connect to server based on client type
120+
if client_type == "aio":
121+
import nats
122+
nc = await nats.connect(url)
123+
else:
124+
from nats.client import connect
125+
nc = await connect(url)
112126
received = 0
113127
first_msg_time = 0.0
114128
last_msg_time = 0.0
@@ -120,8 +134,13 @@ async def run_sub_benchmark(
120134
sub = await nc.subscribe(sub_subject)
121135
start_time = time.perf_counter()
122136

123-
# Receive messages
124-
async for msg in sub:
137+
# Receive messages - handle different iterator styles
138+
if client_type == "aio":
139+
iterator = sub.messages
140+
else:
141+
iterator = sub
142+
143+
async for msg in iterator:
125144
msg_time = time.perf_counter()
126145
if received == 0:
127146
first_msg_time = msg_time
@@ -136,6 +155,9 @@ async def run_sub_benchmark(
136155

137156
duration = last_msg_time - first_msg_time
138157

158+
# Assert we received all expected messages
159+
assert received == msg_count, f"Message loss detected! Received {received}/{msg_count} messages"
160+
139161
# Calculate stats
140162
throughput = received / duration
141163
bytes_per_sec = total_bytes / duration
@@ -167,23 +189,26 @@ async def run_sub_benchmark(
167189

168190
async def run_pubsub_benchmark(
169191
*,
192+
client_type: str = "client",
170193
url: str = "nats://localhost:4222",
171194
msg_count: int = 100_000,
172195
msg_size: int = 128,
173196
subject: str = "test",
174-
headers: Headers | None = None,
197+
headers: dict[str, str] | Any | None = None,
175198
) -> tuple[BenchmarkResults, BenchmarkResults]:
176199
"""Run combined publisher/subscriber benchmark."""
177200

178201
# Start subscriber first
179-
sub_task = asyncio.create_task(run_sub_benchmark(url=url, msg_count=msg_count, sub_subject=subject))
202+
sub_task = asyncio.create_task(
203+
run_sub_benchmark(client_type=client_type, url=url, msg_count=msg_count, sub_subject=subject)
204+
)
180205

181206
# Small delay to ensure subscriber is ready
182207
await asyncio.sleep(0.1)
183208

184209
# Run publisher
185210
pub_results = await run_pub_benchmark(
186-
url=url, msg_count=msg_count, msg_size=msg_size, pub_subject=subject, headers=headers
211+
client_type=client_type, url=url, msg_count=msg_count, msg_size=msg_size, pub_subject=subject, headers=headers
187212
)
188213

189214
# Wait for subscriber to finish
@@ -195,6 +220,8 @@ async def run_pubsub_benchmark(
195220
def main():
196221
"""Main entry point."""
197222
parser = argparse.ArgumentParser(description="NATS benchmarking tool")
223+
parser.add_argument("--client", choices=["client", "aio"], default="client",
224+
help="Client type to use: 'client' (nats-client) or 'aio' (nats.aio)")
198225
parser.add_argument("--url", default="nats://localhost:4222", help="NATS server URL")
199226
parser.add_argument("--msgs", type=int, default=100_000, help="Number of messages to publish")
200227
parser.add_argument("--size", type=int, default=128, help="Size of the message payload")
@@ -213,27 +240,32 @@ def main():
213240
# Create headers if requested
214241
headers = None
215242
if args.headers:
216-
headers = Headers({f"key{i}": f"value{i}" for i in range(args.headers)})
243+
if args.client == "client":
244+
from nats.client import Headers
245+
headers = Headers({f"key{i}": f"value{i}" for i in range(args.headers)})
246+
else:
247+
headers = {f"key{i}": f"value{i}" for i in range(args.headers)}
217248

218249
async def run():
250+
client_name = "nats-client" if args.client == "client" else "nats.aio"
219251
if args.pub and args.sub:
220-
sys.stdout.write(f"\nStarting pub/sub benchmark [msgs={args.msgs:,}, size={args.size:,} B]\n")
252+
sys.stdout.write(f"\nStarting pub/sub benchmark with {client_name} [msgs={args.msgs:,}, size={args.size:,} B]\n")
221253
pub_results, sub_results = await run_pubsub_benchmark(
222-
url=args.url, msg_count=args.msgs, msg_size=args.size, subject=args.subject, headers=headers
254+
client_type=args.client, url=args.url, msg_count=args.msgs, msg_size=args.size, subject=args.subject, headers=headers
223255
)
224256
sys.stdout.write(f"\nPublisher results: {pub_results}\n")
225257
sys.stdout.write(f"\nSubscriber results: {sub_results}\n")
226258

227259
elif args.pub:
228-
sys.stdout.write(f"\nStarting publisher benchmark [msgs={args.msgs:,}, size={args.size:,} B]\n")
260+
sys.stdout.write(f"\nStarting publisher benchmark with {client_name} [msgs={args.msgs:,}, size={args.size:,} B]\n")
229261
results = await run_pub_benchmark(
230-
url=args.url, msg_count=args.msgs, msg_size=args.size, pub_subject=args.subject, headers=headers
262+
client_type=args.client, url=args.url, msg_count=args.msgs, msg_size=args.size, pub_subject=args.subject, headers=headers
231263
)
232264
sys.stdout.write(f"\nResults: {results}\n")
233265

234266
elif args.sub:
235-
sys.stdout.write(f"\nStarting subscriber benchmark [msgs={args.msgs:,}]\n")
236-
results = await run_sub_benchmark(url=args.url, msg_count=args.msgs, sub_subject=args.subject)
267+
sys.stdout.write(f"\nStarting subscriber benchmark with {client_name} [msgs={args.msgs:,}]\n")
268+
results = await run_sub_benchmark(client_type=args.client, url=args.url, msg_count=args.msgs, sub_subject=args.subject)
237269
sys.stdout.write(f"\nResults: {results}\n")
238270

239271
asyncio.run(run())

0 commit comments

Comments
 (0)