Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions nats/src/nats/js/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class StatusCode(str, Enum):
NO_MESSAGES = "404"
REQUEST_TIMEOUT = "408"
CONFLICT = "409"
PIN_ID_MISMATCH = "423"
CONTROL_MESSAGE = "100"


Expand Down Expand Up @@ -491,6 +492,31 @@ class ReplayPolicy(str, Enum):
ORIGINAL = "original"


class PriorityPolicy(str, Enum):
"""Priority policy for priority groups.

Enables flexible failover and priority management when multiple clients are
pulling from the same consumer

Introduced in nats-server 2.12.0.

References:
* `Consumers, Pull consumer priority groups <https://docs.nats.io/release-notes/whats_new/whats_new_211#consumers>`
* `Consumers, Prioritized pull consumer policy <https://docs.nats.io/release-notes/whats_new/whats_new_212#consumers>`
""" # noqa: E501

NONE = ""
"default"
PINNED = "pinned_client"
"pins a consumer to a specific client"
OVERFLOW = "overflow"
"allows for restricting when a consumer will receive messages based on the number of pending messages or acks"
PRIORITIZED = "prioritized"
"""allows for restricting when a consumer will receive messages based on a priority from 0-9 (0 is highest priority & default)
Introduced in nats-server 2.12.0.
"""


@dataclass
class ConsumerConfig(Base):
"""Consumer configuration.
Expand Down Expand Up @@ -543,11 +569,25 @@ class ConsumerConfig(Base):
# Introduced in nats-server 2.11.0.
pause_until: Optional[str] = None

# Priority policy.
# Introduced in nats-server 2.11.0.
priority_policy: Optional[PriorityPolicy] = None

# The duration (seconds) after which the client will be unpinned if no new
# pull requests are sent.Used with PriorityPolicy.PINNED.
# Introduced in nats-server 2.11.0.
priority_timeout: Optional[float] = None

# Priority groups this consumer supports.
# Introduced in nats-server 2.11.0.
priority_groups: Optional[list[str]] = None

@classmethod
def from_response(cls, resp: Dict[str, Any]):
cls._convert_nanoseconds(resp, "ack_wait")
cls._convert_nanoseconds(resp, "idle_heartbeat")
cls._convert_nanoseconds(resp, "inactive_threshold")
cls._convert_nanoseconds(resp, "priority_timeout")
if "backoff" in resp:
resp["backoff"] = [val / _NANOSECOND for val in resp["backoff"]]
return super().from_response(resp)
Expand All @@ -557,6 +597,7 @@ def as_dict(self) -> Dict[str, object]:
result["ack_wait"] = self._to_nanoseconds(self.ack_wait)
result["idle_heartbeat"] = self._to_nanoseconds(self.idle_heartbeat)
result["inactive_threshold"] = self._to_nanoseconds(self.inactive_threshold)
result["priority_timeout"] = self._to_nanoseconds(self.priority_timeout)
if self.backoff:
result["backoff"] = [self._to_nanoseconds(i) for i in self.backoff]
return result
Expand All @@ -570,6 +611,14 @@ class SequenceInfo(Base):
# last_active: Optional[datetime]


@dataclass
class PriorityGroupState(Base):
group: str
pinned_client_id: str
# FIXME: Do not handle dates for now.
# pinned_ts: datetime


@dataclass
class ConsumerInfo(Base):
"""
Expand All @@ -595,6 +644,8 @@ class ConsumerInfo(Base):
# RFC 3339 timestamp until which the consumer is paused.
# Introduced in nats-server 2.11.0.
pause_remaining: Optional[str] = None
# Introduced in nats-server 2.11.0.
priority_groups: Optional[list[PriorityGroupState]] = None

@classmethod
def from_response(cls, resp: Dict[str, Any]):
Expand Down
107 changes: 103 additions & 4 deletions nats/src/nats/js/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ async def pull_subscribe(
pending_msgs_limit: int = DEFAULT_JS_SUB_PENDING_MSGS_LIMIT,
pending_bytes_limit: int = DEFAULT_JS_SUB_PENDING_BYTES_LIMIT,
inbox_prefix: Optional[bytes] = None,
priority_group: Optional[str] = None,
) -> JetStreamContext.PullSubscription:
"""Create consumer and pull subscription.

Expand Down Expand Up @@ -580,6 +581,9 @@ async def main():
if stream is None:
stream = await self._jsm.find_stream_name_by_subject(subject)

if config and config.priority_groups and priority_group is None:
raise ValueError("nats: priority_group is required when consumer has priority_groups configured")

should_create = True
try:
if durable:
Expand All @@ -605,6 +609,10 @@ async def main():
consumer_name = self._nc._nuid.next().decode()
config.name = consumer_name

# Auto created consumers use the priority group, unless priority_groups is set.
if not config.priority_groups and priority_group:
config.priority_groups = [priority_group]

await self._jsm.add_consumer(stream, config=config)

return await self.pull_subscribe_bind(
Expand All @@ -614,6 +622,7 @@ async def main():
pending_bytes_limit=pending_bytes_limit,
pending_msgs_limit=pending_msgs_limit,
name=consumer_name,
priority_group=priority_group,
)

async def pull_subscribe_bind(
Expand All @@ -625,6 +634,7 @@ async def pull_subscribe_bind(
pending_bytes_limit: int = DEFAULT_JS_SUB_PENDING_BYTES_LIMIT,
name: Optional[str] = None,
durable: Optional[str] = None,
priority_group: Optional[str] = None,
) -> JetStreamContext.PullSubscription:
"""
pull_subscribe returns a `PullSubscription` that can be delivered messages
Expand Down Expand Up @@ -680,6 +690,7 @@ async def main():
stream=stream,
consumer=consumer_name,
deliver=deliver,
group=priority_group,
)

@classmethod
Expand All @@ -703,11 +714,16 @@ def _is_temporary_error(cls, status: Optional[str]) -> bool:
status == api.StatusCode.NO_MESSAGES
or status == api.StatusCode.CONFLICT
or status == api.StatusCode.REQUEST_TIMEOUT
or status == api.StatusCode.PIN_ID_MISMATCH
):
return True
else:
return False

@classmethod
def _is_pin_id_mismatch_error(cls, status: Optional[str]) -> bool:
return status == api.StatusCode.PIN_ID_MISMATCH

@classmethod
def _is_heartbeat(cls, status: Optional[str]) -> bool:
if status == api.StatusCode.CONTROL_MESSAGE:
Expand Down Expand Up @@ -997,6 +1013,7 @@ def __init__(
stream: str,
consumer: str,
deliver: bytes,
group: Optional[str] = None,
) -> None:
# JS/JSM context
self._js = js
Expand All @@ -1009,6 +1026,8 @@ def __init__(
prefix = self._js._prefix
self._nms = f"{prefix}.CONSUMER.MSG.NEXT.{stream}.{consumer}"
self._deliver = deliver.decode()
self._pin_id: Optional[str] = None
self._group = group

@property
def pending_msgs(self) -> int:
Expand Down Expand Up @@ -1055,6 +1074,9 @@ async def fetch(
batch: int = 1,
timeout: Optional[float] = 5,
heartbeat: Optional[float] = None,
min_pending: Optional[int] = None,
min_ack_pending: Optional[int] = None,
priority: Optional[int] = None,
) -> List[Msg]:
"""
fetch makes a request to JetStream to be delivered a set of messages.
Expand Down Expand Up @@ -1095,17 +1117,26 @@ async def main():

expires = int(timeout * 1_000_000_000) - 100_000 if timeout else None
if batch == 1:
msg = await self._fetch_one(expires, timeout, heartbeat)
msg = await self._fetch_one(expires, timeout, heartbeat, min_pending, min_ack_pending, priority)
return [msg]
msgs = await self._fetch_n(batch, expires, timeout, heartbeat)
msgs = await self._fetch_n(batch, expires, timeout, heartbeat, min_pending, min_ack_pending, priority)
return msgs

async def _fetch_one(
self,
expires: Optional[int],
timeout: Optional[float],
heartbeat: Optional[float] = None,
min_pending: Optional[int] = None,
min_ack_pending: Optional[int] = None,
priority: Optional[int] = None,
) -> Msg:
if min_pending is not None and not (min_pending > 0):
raise ValueError("nats: min_pending must be more than 0")
if min_ack_pending is not None and not (min_ack_pending > 0):
raise ValueError("nats: min_ack_pending must be more than 0")
if priority is not None and not (0 <= priority <= 9):
raise ValueError("nats: priority must be 0-9")
queue = self._sub._pending_queue

# Check the next message in case there are any.
Expand All @@ -1130,7 +1161,17 @@ async def _fetch_one(
next_req["expires"] = int(expires)
if heartbeat:
next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds

if self._group:
next_req["group"] = self._group
pin_id = self.pin_id
if pin_id:
next_req["id"] = pin_id
if min_pending:
next_req["min_pending"] = min_pending
if min_ack_pending:
next_req["min_ack_pending"] = min_ack_pending
if priority:
next_req["priority"] = priority
await self._nc.publish(
self._nms,
json.dumps(next_req).encode(),
Expand All @@ -1152,13 +1193,19 @@ async def _fetch_one(
got_any_response = True
continue

if JetStreamContext._is_pin_id_mismatch_error(status):
self.pin_id = ""

# In case of a temporary error, treat it as a timeout to retry.
if JetStreamContext._is_temporary_error(status):
raise nats.errors.TimeoutError
else:
# Any other type of status message is an error.
raise nats.js.errors.APIError.from_msg(msg)
else:
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
if pin_id:
self.pin_id = pin_id
return msg
except asyncio.TimeoutError:
deadline = JetStreamContext._time_until(timeout, start_time)
Expand All @@ -1177,6 +1224,9 @@ async def _fetch_n(
expires: Optional[int],
timeout: Optional[float],
heartbeat: Optional[float] = None,
min_pending: Optional[int] = None,
min_ack_pending: Optional[int] = None,
priority: Optional[int] = None,
) -> List[Msg]:
msgs = []
queue = self._sub._pending_queue
Expand Down Expand Up @@ -1210,6 +1260,17 @@ async def _fetch_n(
if heartbeat:
next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds
next_req["no_wait"] = True
if self._group:
next_req["group"] = self._group
pin_id = self.pin_id
if pin_id:
next_req["id"] = pin_id
if min_pending:
next_req["min_pending"] = min_pending
if min_ack_pending:
next_req["min_ack_pending"] = min_ack_pending
if priority:
next_req["priority"] = priority
await self._nc.publish(
self._nms,
json.dumps(next_req).encode(),
Expand All @@ -1233,8 +1294,13 @@ async def _fetch_n(
# a possible i/o timeout error or due to a disconnection.
got_any_response = True
pass
elif JetStreamContext._is_pin_id_mismatch_error(status):
self.pin_id = ""
elif JetStreamContext._is_processable_msg(status, msg):
# First processable message received, do not raise error from now.
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
if pin_id:
self.pin_id = pin_id
msgs.append(msg)
needed -= 1

Expand All @@ -1251,7 +1317,12 @@ async def _fetch_n(
# Skip heartbeats.
got_any_response = True
continue
elif JetStreamContext._is_pin_id_mismatch_error(status):
self.pin_id = ""
elif JetStreamContext._is_processable_msg(status, msg):
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
if pin_id:
self.pin_id = pin_id
needed -= 1
msgs.append(msg)
except asyncio.TimeoutError:
Expand All @@ -1271,7 +1342,17 @@ async def _fetch_n(
next_req["expires"] = expires
if heartbeat:
next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds

if self._group:
next_req["group"] = self._group
pin_id = self.pin_id
if pin_id:
next_req["id"] = pin_id
if min_pending:
next_req["min_pending"] = min_pending
if min_ack_pending:
next_req["min_ack_pending"] = min_ack_pending
if priority:
next_req["priority"] = priority
await self._nc.publish(
self._nms,
json.dumps(next_req).encode(),
Expand Down Expand Up @@ -1310,8 +1391,13 @@ async def _fetch_n(
if JetStreamContext._is_heartbeat(status):
got_any_response = True
continue
if JetStreamContext._is_pin_id_mismatch_error(status):
self.pin_id = ""

if not status:
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
if pin_id:
self.pin_id = pin_id
needed -= 1
msgs.append(msg)
break
Expand All @@ -1335,7 +1421,12 @@ async def _fetch_n(
if JetStreamContext._is_heartbeat(status):
got_any_response = True
continue
if JetStreamContext._is_pin_id_mismatch_error(status):
self.pin_id = ""
if JetStreamContext._is_processable_msg(status, msg):
pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None
if pin_id:
self.pin_id = pin_id
needed -= 1
msgs.append(msg)
except asyncio.TimeoutError:
Expand All @@ -1348,6 +1439,14 @@ async def _fetch_n(

return msgs

@property
def pin_id(self) -> Optional[str]:
return self._pin_id

@pin_id.setter
def pin_id(self, pin_id: str) -> None:
self._pin_id = pin_id

######################
# #
# KeyValue Context #
Expand Down
Loading
Loading