Skip to content

Commit 424c353

Browse files
committed
Add jetstream consumer priority groups
1 parent 170ae2c commit 424c353

File tree

5 files changed

+498
-4
lines changed

5 files changed

+498
-4
lines changed

nats/src/nats/js/api.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class StatusCode(str, Enum):
4646
NO_MESSAGES = "404"
4747
REQUEST_TIMEOUT = "408"
4848
CONFLICT = "409"
49+
PIN_ID_MISMATCH = "423"
4950
CONTROL_MESSAGE = "100"
5051

5152

@@ -491,6 +492,31 @@ class ReplayPolicy(str, Enum):
491492
ORIGINAL = "original"
492493

493494

495+
class PriorityPolicy(str, Enum):
496+
"""Priority policy for priority groups.
497+
498+
Enables flexible failover and priority management when multiple clients are
499+
pulling from the same consumer
500+
501+
Introduced in nats-server 2.12.0.
502+
503+
References:
504+
* `Consumers, Pull consumer priority groups <https://docs.nats.io/release-notes/whats_new/whats_new_211#consumers>`
505+
* `Consumers, Prioritized pull consumer policy <https://docs.nats.io/release-notes/whats_new/whats_new_212#consumers>`
506+
""" # noqa: E501
507+
508+
NONE = ""
509+
"default"
510+
PINNED = "pinned_client"
511+
"pins a consumer to a specific client"
512+
OVERFLOW = "overflow"
513+
"allows for restricting when a consumer will receive messages based on the number of pending messages or acks"
514+
PRIORITIZED = "prioritized"
515+
"""allows for restricting when a consumer will receive messages based on a priority from 0-9 (0 is highest priority & default)
516+
Introduced in nats-server 2.12.0.
517+
"""
518+
519+
494520
@dataclass
495521
class ConsumerConfig(Base):
496522
"""Consumer configuration.
@@ -543,11 +569,25 @@ class ConsumerConfig(Base):
543569
# Introduced in nats-server 2.11.0.
544570
pause_until: Optional[str] = None
545571

572+
# Priority policy.
573+
# Introduced in nats-server 2.11.0.
574+
priority_policy: Optional[PriorityPolicy] = None
575+
576+
# The duration (seconds) after which the client will be unpinned if no new
577+
# pull requests are sent.Used with PriorityPolicy.PINNED.
578+
# Introduced in nats-server 2.11.0.
579+
priority_timeout: Optional[float] = None
580+
581+
# Priority groups this consumer supports.
582+
# Introduced in nats-server 2.11.0.
583+
priority_groups: Optional[list[str]] = None
584+
546585
@classmethod
547586
def from_response(cls, resp: Dict[str, Any]):
548587
cls._convert_nanoseconds(resp, "ack_wait")
549588
cls._convert_nanoseconds(resp, "idle_heartbeat")
550589
cls._convert_nanoseconds(resp, "inactive_threshold")
590+
cls._convert_nanoseconds(resp, "priority_timeout")
551591
if "backoff" in resp:
552592
resp["backoff"] = [val / _NANOSECOND for val in resp["backoff"]]
553593
return super().from_response(resp)
@@ -557,6 +597,7 @@ def as_dict(self) -> Dict[str, object]:
557597
result["ack_wait"] = self._to_nanoseconds(self.ack_wait)
558598
result["idle_heartbeat"] = self._to_nanoseconds(self.idle_heartbeat)
559599
result["inactive_threshold"] = self._to_nanoseconds(self.inactive_threshold)
600+
result["priority_timeout"] = self._to_nanoseconds(self.priority_timeout)
560601
if self.backoff:
561602
result["backoff"] = [self._to_nanoseconds(i) for i in self.backoff]
562603
return result
@@ -570,6 +611,14 @@ class SequenceInfo(Base):
570611
# last_active: Optional[datetime]
571612

572613

614+
@dataclass
615+
class PriorityGroupState(Base):
616+
group: str
617+
pinned_client_id: str
618+
# FIXME: Do not handle dates for now.
619+
# pinned_ts: datetime
620+
621+
573622
@dataclass
574623
class ConsumerInfo(Base):
575624
"""
@@ -595,6 +644,8 @@ class ConsumerInfo(Base):
595644
# RFC 3339 timestamp until which the consumer is paused.
596645
# Introduced in nats-server 2.11.0.
597646
pause_remaining: Optional[str] = None
647+
# Introduced in nats-server 2.11.0.
648+
priority_groups: Optional[list[PriorityGroupState]] = None
598649

599650
@classmethod
600651
def from_response(cls, resp: Dict[str, Any]):

nats/src/nats/js/client.py

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ async def pull_subscribe(
546546
pending_msgs_limit: int = DEFAULT_JS_SUB_PENDING_MSGS_LIMIT,
547547
pending_bytes_limit: int = DEFAULT_JS_SUB_PENDING_BYTES_LIMIT,
548548
inbox_prefix: Optional[bytes] = None,
549+
priority_group: Optional[str] = None,
549550
) -> JetStreamContext.PullSubscription:
550551
"""Create consumer and pull subscription.
551552
@@ -580,6 +581,9 @@ async def main():
580581
if stream is None:
581582
stream = await self._jsm.find_stream_name_by_subject(subject)
582583

584+
if config and config.priority_groups and priority_group is None:
585+
raise ValueError("nats: priority_group is required when consumer has priority_groups configured")
586+
583587
should_create = True
584588
try:
585589
if durable:
@@ -605,6 +609,10 @@ async def main():
605609
consumer_name = self._nc._nuid.next().decode()
606610
config.name = consumer_name
607611

612+
# Auto created consumers use the priority group, unless priority_groups is set.
613+
if not config.priority_groups and priority_group:
614+
config.priority_groups = [priority_group]
615+
608616
await self._jsm.add_consumer(stream, config=config)
609617

610618
return await self.pull_subscribe_bind(
@@ -614,6 +622,7 @@ async def main():
614622
pending_bytes_limit=pending_bytes_limit,
615623
pending_msgs_limit=pending_msgs_limit,
616624
name=consumer_name,
625+
priority_group=priority_group,
617626
)
618627

619628
async def pull_subscribe_bind(
@@ -625,6 +634,7 @@ async def pull_subscribe_bind(
625634
pending_bytes_limit: int = DEFAULT_JS_SUB_PENDING_BYTES_LIMIT,
626635
name: Optional[str] = None,
627636
durable: Optional[str] = None,
637+
priority_group: Optional[str] = None,
628638
) -> JetStreamContext.PullSubscription:
629639
"""
630640
pull_subscribe returns a `PullSubscription` that can be delivered messages
@@ -680,6 +690,7 @@ async def main():
680690
stream=stream,
681691
consumer=consumer_name,
682692
deliver=deliver,
693+
group=priority_group,
683694
)
684695

685696
@classmethod
@@ -703,11 +714,16 @@ def _is_temporary_error(cls, status: Optional[str]) -> bool:
703714
status == api.StatusCode.NO_MESSAGES
704715
or status == api.StatusCode.CONFLICT
705716
or status == api.StatusCode.REQUEST_TIMEOUT
717+
or status == api.StatusCode.PIN_ID_MISMATCH
706718
):
707719
return True
708720
else:
709721
return False
710722

723+
@classmethod
724+
def _is_pin_id_mismatch_error(cls, status: Optional[str]) -> bool:
725+
return status == api.StatusCode.PIN_ID_MISMATCH
726+
711727
@classmethod
712728
def _is_heartbeat(cls, status: Optional[str]) -> bool:
713729
if status == api.StatusCode.CONTROL_MESSAGE:
@@ -997,6 +1013,7 @@ def __init__(
9971013
stream: str,
9981014
consumer: str,
9991015
deliver: bytes,
1016+
group: Optional[str] = None,
10001017
) -> None:
10011018
# JS/JSM context
10021019
self._js = js
@@ -1009,6 +1026,8 @@ def __init__(
10091026
prefix = self._js._prefix
10101027
self._nms = f"{prefix}.CONSUMER.MSG.NEXT.{stream}.{consumer}"
10111028
self._deliver = deliver.decode()
1029+
self._pin_id: Optional[str] = None
1030+
self._group = group
10121031

10131032
@property
10141033
def pending_msgs(self) -> int:
@@ -1055,6 +1074,9 @@ async def fetch(
10551074
batch: int = 1,
10561075
timeout: Optional[float] = 5,
10571076
heartbeat: Optional[float] = None,
1077+
min_pending: Optional[int] = None,
1078+
min_ack_pending: Optional[int] = None,
1079+
priority: Optional[int] = None,
10581080
) -> List[Msg]:
10591081
"""
10601082
fetch makes a request to JetStream to be delivered a set of messages.
@@ -1095,17 +1117,26 @@ async def main():
10951117

10961118
expires = int(timeout * 1_000_000_000) - 100_000 if timeout else None
10971119
if batch == 1:
1098-
msg = await self._fetch_one(expires, timeout, heartbeat)
1120+
msg = await self._fetch_one(expires, timeout, heartbeat, min_pending, min_ack_pending, priority)
10991121
return [msg]
1100-
msgs = await self._fetch_n(batch, expires, timeout, heartbeat)
1122+
msgs = await self._fetch_n(batch, expires, timeout, heartbeat, min_pending, min_ack_pending, priority)
11011123
return msgs
11021124

11031125
async def _fetch_one(
11041126
self,
11051127
expires: Optional[int],
11061128
timeout: Optional[float],
11071129
heartbeat: Optional[float] = None,
1130+
min_pending: Optional[int] = None,
1131+
min_ack_pending: Optional[int] = None,
1132+
priority: Optional[int] = None,
11081133
) -> Msg:
1134+
if min_pending is not None and not (min_pending > 0):
1135+
raise ValueError("nats: min_pending must be more than 0")
1136+
if min_ack_pending is not None and not (min_ack_pending > 0):
1137+
raise ValueError("nats: min_ack_pending must be more than 0")
1138+
if priority is not None and not (0 <= priority <= 9):
1139+
raise ValueError("nats: priority must be 0-9")
11091140
queue = self._sub._pending_queue
11101141

11111142
# Check the next message in case there are any.
@@ -1130,7 +1161,17 @@ async def _fetch_one(
11301161
next_req["expires"] = int(expires)
11311162
if heartbeat:
11321163
next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds
1133-
1164+
if self._group:
1165+
next_req["group"] = self._group
1166+
pin_id = self.pin_id
1167+
if pin_id:
1168+
next_req["id"] = pin_id
1169+
if min_pending:
1170+
next_req["min_pending"] = min_pending
1171+
if min_ack_pending:
1172+
next_req["min_ack_pending"] = min_ack_pending
1173+
if priority:
1174+
next_req["priority"] = priority
11341175
await self._nc.publish(
11351176
self._nms,
11361177
json.dumps(next_req).encode(),
@@ -1152,13 +1193,19 @@ async def _fetch_one(
11521193
got_any_response = True
11531194
continue
11541195

1196+
if JetStreamContext._is_pin_id_mismatch_error(status):
1197+
self.pin_id = ""
1198+
11551199
# In case of a temporary error, treat it as a timeout to retry.
11561200
if JetStreamContext._is_temporary_error(status):
11571201
raise nats.errors.TimeoutError
11581202
else:
11591203
# Any other type of status message is an error.
11601204
raise nats.js.errors.APIError.from_msg(msg)
11611205
else:
1206+
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
1207+
if pin_id:
1208+
self.pin_id = pin_id
11621209
return msg
11631210
except asyncio.TimeoutError:
11641211
deadline = JetStreamContext._time_until(timeout, start_time)
@@ -1177,6 +1224,9 @@ async def _fetch_n(
11771224
expires: Optional[int],
11781225
timeout: Optional[float],
11791226
heartbeat: Optional[float] = None,
1227+
min_pending: Optional[int] = None,
1228+
min_ack_pending: Optional[int] = None,
1229+
priority: Optional[int] = None,
11801230
) -> List[Msg]:
11811231
msgs = []
11821232
queue = self._sub._pending_queue
@@ -1210,6 +1260,17 @@ async def _fetch_n(
12101260
if heartbeat:
12111261
next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds
12121262
next_req["no_wait"] = True
1263+
if self._group:
1264+
next_req["group"] = self._group
1265+
pin_id = self.pin_id
1266+
if pin_id:
1267+
next_req["id"] = pin_id
1268+
if min_pending:
1269+
next_req["min_pending"] = min_pending
1270+
if min_ack_pending:
1271+
next_req["min_ack_pending"] = min_ack_pending
1272+
if priority:
1273+
next_req["priority"] = priority
12131274
await self._nc.publish(
12141275
self._nms,
12151276
json.dumps(next_req).encode(),
@@ -1233,8 +1294,13 @@ async def _fetch_n(
12331294
# a possible i/o timeout error or due to a disconnection.
12341295
got_any_response = True
12351296
pass
1297+
elif JetStreamContext._is_pin_id_mismatch_error(status):
1298+
self.pin_id = ""
12361299
elif JetStreamContext._is_processable_msg(status, msg):
12371300
# First processable message received, do not raise error from now.
1301+
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
1302+
if pin_id:
1303+
self.pin_id = pin_id
12381304
msgs.append(msg)
12391305
needed -= 1
12401306

@@ -1251,7 +1317,12 @@ async def _fetch_n(
12511317
# Skip heartbeats.
12521318
got_any_response = True
12531319
continue
1320+
elif JetStreamContext._is_pin_id_mismatch_error(status):
1321+
self.pin_id = ""
12541322
elif JetStreamContext._is_processable_msg(status, msg):
1323+
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
1324+
if pin_id:
1325+
self.pin_id = pin_id
12551326
needed -= 1
12561327
msgs.append(msg)
12571328
except asyncio.TimeoutError:
@@ -1271,7 +1342,17 @@ async def _fetch_n(
12711342
next_req["expires"] = expires
12721343
if heartbeat:
12731344
next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds
1274-
1345+
if self._group:
1346+
next_req["group"] = self._group
1347+
pin_id = self.pin_id
1348+
if pin_id:
1349+
next_req["id"] = pin_id
1350+
if min_pending:
1351+
next_req["min_pending"] = min_pending
1352+
if min_ack_pending:
1353+
next_req["min_ack_pending"] = min_ack_pending
1354+
if priority:
1355+
next_req["priority"] = priority
12751356
await self._nc.publish(
12761357
self._nms,
12771358
json.dumps(next_req).encode(),
@@ -1310,8 +1391,13 @@ async def _fetch_n(
13101391
if JetStreamContext._is_heartbeat(status):
13111392
got_any_response = True
13121393
continue
1394+
if JetStreamContext._is_pin_id_mismatch_error(status):
1395+
self.pin_id = ""
13131396

13141397
if not status:
1398+
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
1399+
if pin_id:
1400+
self.pin_id = pin_id
13151401
needed -= 1
13161402
msgs.append(msg)
13171403
break
@@ -1335,7 +1421,12 @@ async def _fetch_n(
13351421
if JetStreamContext._is_heartbeat(status):
13361422
got_any_response = True
13371423
continue
1424+
if JetStreamContext._is_pin_id_mismatch_error(status):
1425+
self.pin_id = ""
13381426
if JetStreamContext._is_processable_msg(status, msg):
1427+
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
1428+
if pin_id:
1429+
self.pin_id = pin_id
13391430
needed -= 1
13401431
msgs.append(msg)
13411432
except asyncio.TimeoutError:
@@ -1348,6 +1439,14 @@ async def _fetch_n(
13481439

13491440
return msgs
13501441

1442+
@property
1443+
def pin_id(self) -> Optional[str]:
1444+
return self._pin_id
1445+
1446+
@pin_id.setter
1447+
def pin_id(self, pin_id: str) -> None:
1448+
self._pin_id = pin_id
1449+
13511450
######################
13521451
# #
13531452
# KeyValue Context #

0 commit comments

Comments
 (0)