diff --git a/nats/src/nats/js/api.py b/nats/src/nats/js/api.py index cdd254db..5e2a205d 100644 --- a/nats/src/nats/js/api.py +++ b/nats/src/nats/js/api.py @@ -46,6 +46,7 @@ class StatusCode(str, Enum): NO_MESSAGES = "404" REQUEST_TIMEOUT = "408" CONFLICT = "409" + PIN_ID_MISMATCH = "423" CONTROL_MESSAGE = "100" @@ -491,6 +492,31 @@ class ReplayPolicy(str, Enum): ORIGINAL = "original" +class PriorityPolicy(str, Enum): + """Priority policy for priority groups. + + Enables flexible failover and priority management when multiple clients are + pulling from the same consumer. + + Introduced in nats-server 2.11.0 (the ``prioritized`` policy requires 2.12.0). + + References: + * `Consumers, Pull consumer priority groups ` + * `Consumers, Prioritized pull consumer policy ` + """ # noqa: E501 + + NONE = "" + "default" + PINNED = "pinned_client" + "pins a consumer to a specific client" + OVERFLOW = "overflow" + "allows for restricting when a consumer will receive messages based on the number of pending messages or acks" + PRIORITIZED = "prioritized" + """allows for restricting when a consumer will receive messages based on a priority from 0-9 (0 is highest priority & default) + Introduced in nats-server 2.12.0. + """ + + @dataclass class ConsumerConfig(Base): """Consumer configuration. @@ -543,11 +569,25 @@ class ConsumerConfig(Base): # Introduced in nats-server 2.11.0. pause_until: Optional[str] = None + # Priority policy. + # Introduced in nats-server 2.11.0. + priority_policy: Optional[PriorityPolicy] = None + + # The duration (seconds) after which the client will be unpinned if no new + # pull requests are sent. Used with PriorityPolicy.PINNED. + # Introduced in nats-server 2.11.0. + priority_timeout: Optional[float] = None + + # Priority groups this consumer supports. + # Introduced in nats-server 2.11.0. 
+ priority_groups: Optional[list[str]] = None + @classmethod def from_response(cls, resp: Dict[str, Any]): cls._convert_nanoseconds(resp, "ack_wait") cls._convert_nanoseconds(resp, "idle_heartbeat") cls._convert_nanoseconds(resp, "inactive_threshold") + cls._convert_nanoseconds(resp, "priority_timeout") if "backoff" in resp: resp["backoff"] = [val / _NANOSECOND for val in resp["backoff"]] return super().from_response(resp) @@ -557,6 +597,7 @@ def as_dict(self) -> Dict[str, object]: result["ack_wait"] = self._to_nanoseconds(self.ack_wait) result["idle_heartbeat"] = self._to_nanoseconds(self.idle_heartbeat) result["inactive_threshold"] = self._to_nanoseconds(self.inactive_threshold) + result["priority_timeout"] = self._to_nanoseconds(self.priority_timeout) if self.backoff: result["backoff"] = [self._to_nanoseconds(i) for i in self.backoff] return result @@ -570,6 +611,14 @@ class SequenceInfo(Base): # last_active: Optional[datetime] +@dataclass +class PriorityGroupState(Base): + group: str + pinned_client_id: str + # FIXME: Do not handle dates for now. + # pinned_ts: datetime + + @dataclass class ConsumerInfo(Base): """ @@ -595,6 +644,8 @@ class ConsumerInfo(Base): # RFC 3339 timestamp until which the consumer is paused. # Introduced in nats-server 2.11.0. pause_remaining: Optional[str] = None + # Introduced in nats-server 2.11.0. + priority_groups: Optional[list[PriorityGroupState]] = None @classmethod def from_response(cls, resp: Dict[str, Any]): diff --git a/nats/src/nats/js/client.py b/nats/src/nats/js/client.py index 66279635..3eeb5796 100644 --- a/nats/src/nats/js/client.py +++ b/nats/src/nats/js/client.py @@ -546,6 +546,7 @@ async def pull_subscribe( pending_msgs_limit: int = DEFAULT_JS_SUB_PENDING_MSGS_LIMIT, pending_bytes_limit: int = DEFAULT_JS_SUB_PENDING_BYTES_LIMIT, inbox_prefix: Optional[bytes] = None, + priority_group: Optional[str] = None, ) -> JetStreamContext.PullSubscription: """Create consumer and pull subscription. 
@@ -580,6 +581,9 @@ async def main(): if stream is None: stream = await self._jsm.find_stream_name_by_subject(subject) + if config and config.priority_groups and priority_group is None: + raise ValueError("nats: priority_group is required when consumer has priority_groups configured") + should_create = True try: if durable: @@ -605,6 +609,10 @@ async def main(): consumer_name = self._nc._nuid.next().decode() config.name = consumer_name + # Auto created consumers use the priority group, unless priority_groups is set. + if not config.priority_groups and priority_group: + config.priority_groups = [priority_group] + await self._jsm.add_consumer(stream, config=config) return await self.pull_subscribe_bind( @@ -614,6 +622,7 @@ async def main(): pending_bytes_limit=pending_bytes_limit, pending_msgs_limit=pending_msgs_limit, name=consumer_name, + priority_group=priority_group, ) async def pull_subscribe_bind( @@ -625,6 +634,7 @@ async def pull_subscribe_bind( pending_bytes_limit: int = DEFAULT_JS_SUB_PENDING_BYTES_LIMIT, name: Optional[str] = None, durable: Optional[str] = None, + priority_group: Optional[str] = None, ) -> JetStreamContext.PullSubscription: """ pull_subscribe returns a `PullSubscription` that can be delivered messages @@ -680,6 +690,7 @@ async def main(): stream=stream, consumer=consumer_name, deliver=deliver, + group=priority_group, ) @classmethod @@ -703,11 +714,16 @@ def _is_temporary_error(cls, status: Optional[str]) -> bool: status == api.StatusCode.NO_MESSAGES or status == api.StatusCode.CONFLICT or status == api.StatusCode.REQUEST_TIMEOUT + or status == api.StatusCode.PIN_ID_MISMATCH ): return True else: return False + @classmethod + def _is_pin_id_mismatch_error(cls, status: Optional[str]) -> bool: + return status == api.StatusCode.PIN_ID_MISMATCH + @classmethod def _is_heartbeat(cls, status: Optional[str]) -> bool: if status == api.StatusCode.CONTROL_MESSAGE: @@ -997,6 +1013,7 @@ def __init__( stream: str, consumer: str, deliver: bytes, + group: 
Optional[str] = None, ) -> None: # JS/JSM context self._js = js @@ -1009,6 +1026,8 @@ def __init__( prefix = self._js._prefix self._nms = f"{prefix}.CONSUMER.MSG.NEXT.{stream}.{consumer}" self._deliver = deliver.decode() + self._pin_id: Optional[str] = None + self._group = group @property def pending_msgs(self) -> int: @@ -1055,6 +1074,9 @@ async def fetch( batch: int = 1, timeout: Optional[float] = 5, heartbeat: Optional[float] = None, + min_pending: Optional[int] = None, + min_ack_pending: Optional[int] = None, + priority: Optional[int] = None, ) -> List[Msg]: """ fetch makes a request to JetStream to be delivered a set of messages. @@ -1095,9 +1117,9 @@ async def main(): expires = int(timeout * 1_000_000_000) - 100_000 if timeout else None if batch == 1: - msg = await self._fetch_one(expires, timeout, heartbeat) + msg = await self._fetch_one(expires, timeout, heartbeat, min_pending, min_ack_pending, priority) return [msg] - msgs = await self._fetch_n(batch, expires, timeout, heartbeat) + msgs = await self._fetch_n(batch, expires, timeout, heartbeat, min_pending, min_ack_pending, priority) return msgs async def _fetch_one( @@ -1105,7 +1127,16 @@ async def _fetch_one( expires: Optional[int], timeout: Optional[float], heartbeat: Optional[float] = None, + min_pending: Optional[int] = None, + min_ack_pending: Optional[int] = None, + priority: Optional[int] = None, ) -> Msg: + if min_pending is not None and not (min_pending > 0): + raise ValueError("nats: min_pending must be more than 0") + if min_ack_pending is not None and not (min_ack_pending > 0): + raise ValueError("nats: min_ack_pending must be more than 0") + if priority is not None and not (0 <= priority <= 9): + raise ValueError("nats: priority must be 0-9") queue = self._sub._pending_queue # Check the next message in case there are any. 
@@ -1130,7 +1161,17 @@ async def _fetch_one( next_req["expires"] = int(expires) if heartbeat: next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds - + if self._group: + next_req["group"] = self._group + pin_id = self.pin_id + if pin_id: + next_req["id"] = pin_id + if min_pending: + next_req["min_pending"] = min_pending + if min_ack_pending: + next_req["min_ack_pending"] = min_ack_pending + if priority: + next_req["priority"] = priority await self._nc.publish( self._nms, json.dumps(next_req).encode(), @@ -1152,6 +1193,9 @@ async def _fetch_one( got_any_response = True continue + if JetStreamContext._is_pin_id_mismatch_error(status): + self.pin_id = "" + # In case of a temporary error, treat it as a timeout to retry. if JetStreamContext._is_temporary_error(status): raise nats.errors.TimeoutError @@ -1159,6 +1203,9 @@ async def _fetch_one( # Any other type of status message is an error. raise nats.js.errors.APIError.from_msg(msg) else: + pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None + if pin_id: + self.pin_id = pin_id return msg except asyncio.TimeoutError: deadline = JetStreamContext._time_until(timeout, start_time) @@ -1177,6 +1224,9 @@ async def _fetch_n( expires: Optional[int], timeout: Optional[float], heartbeat: Optional[float] = None, + min_pending: Optional[int] = None, + min_ack_pending: Optional[int] = None, + priority: Optional[int] = None, ) -> List[Msg]: msgs = [] queue = self._sub._pending_queue @@ -1210,6 +1260,17 @@ async def _fetch_n( if heartbeat: next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds next_req["no_wait"] = True + if self._group: + next_req["group"] = self._group + pin_id = self.pin_id + if pin_id: + next_req["id"] = pin_id + if min_pending: + next_req["min_pending"] = min_pending + if min_ack_pending: + next_req["min_ack_pending"] = min_ack_pending + if priority: + next_req["priority"] = priority await self._nc.publish( self._nms, json.dumps(next_req).encode(), @@ 
-1233,8 +1294,13 @@ async def _fetch_n( # a possible i/o timeout error or due to a disconnection. got_any_response = True pass + elif JetStreamContext._is_pin_id_mismatch_error(status): + self.pin_id = "" elif JetStreamContext._is_processable_msg(status, msg): # First processable message received, do not raise error from now. + pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None + if pin_id: + self.pin_id = pin_id msgs.append(msg) needed -= 1 @@ -1251,7 +1317,12 @@ async def _fetch_n( # Skip heartbeats. got_any_response = True continue + elif JetStreamContext._is_pin_id_mismatch_error(status): + self.pin_id = "" elif JetStreamContext._is_processable_msg(status, msg): + pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None + if pin_id: + self.pin_id = pin_id needed -= 1 msgs.append(msg) except asyncio.TimeoutError: @@ -1271,7 +1342,17 @@ async def _fetch_n( next_req["expires"] = expires if heartbeat: next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds - + if self._group: + next_req["group"] = self._group + pin_id = self.pin_id + if pin_id: + next_req["id"] = pin_id + if min_pending: + next_req["min_pending"] = min_pending + if min_ack_pending: + next_req["min_ack_pending"] = min_ack_pending + if priority: + next_req["priority"] = priority await self._nc.publish( self._nms, json.dumps(next_req).encode(), @@ -1310,8 +1391,13 @@ async def _fetch_n( if JetStreamContext._is_heartbeat(status): got_any_response = True continue + if JetStreamContext._is_pin_id_mismatch_error(status): + self.pin_id = "" if not status: + pin_id = msg.headers.get("Nats-Pin-Id") if msg.headers else None + if pin_id: + self.pin_id = pin_id needed -= 1 msgs.append(msg) break @@ -1335,7 +1421,12 @@ async def _fetch_n( if JetStreamContext._is_heartbeat(status): got_any_response = True continue + if JetStreamContext._is_pin_id_mismatch_error(status): + self.pin_id = "" if JetStreamContext._is_processable_msg(status, msg): + pin_id = 
msg.headers.get("Nats-Pin-Id") if msg.headers else None + if pin_id: + self.pin_id = pin_id needed -= 1 msgs.append(msg) except asyncio.TimeoutError: @@ -1348,6 +1439,14 @@ async def _fetch_n( return msgs + @property + def pin_id(self) -> Optional[str]: + return self._pin_id + + @pin_id.setter + def pin_id(self, pin_id: str) -> None: + self._pin_id = pin_id + ###################### # # # KeyValue Context # diff --git a/nats/src/nats/js/errors.py b/nats/src/nats/js/errors.py index d594587d..cbb55d36 100644 --- a/nats/src/nats/js/errors.py +++ b/nats/src/nats/js/errors.py @@ -83,6 +83,8 @@ def from_error(cls, err: Dict[str, Any]): raise ServiceUnavailableError(**err) elif code == 500: raise ServerError(**err) + elif code == 423: + raise PinIdMismatchError(**err) elif code == 404: raise NotFoundError(**err) elif code == 400: @@ -112,6 +114,17 @@ class ServerError(APIError): pass +class PinIdMismatchError(APIError): + """ + A 423 error + + PinIdMismatchError is returned when Pin ID sent in the request does not match + the currently pinned consumer subscriber ID on the server. + """ + + pass + + class NotFoundError(APIError): """ A 404 error diff --git a/nats/src/nats/js/manager.py b/nats/src/nats/js/manager.py index 59f6e5f8..19df91e4 100644 --- a/nats/src/nats/js/manager.py +++ b/nats/src/nats/js/manager.py @@ -420,6 +420,15 @@ async def get_last_msg( """ return await self.get_msg(stream_name, subject=subject, direct=direct) + async def unpin_consumer(self, stream_name: str, consumer_name: str, group: str) -> None: + """ + unpin_consumer unpins a pinned consumer. 
+ """ + req_subject = f"{self._prefix}.CONSUMER.UNPIN.{stream_name}.{consumer_name}" + req = {"group": group} + data = json.dumps(req) + _ = await self._api_request(req_subject, data.encode()) + async def _api_request( self, req_subject: str, diff --git a/nats/tests/test_js.py b/nats/tests/test_js.py index 10c5f329..b9d7946d 100644 --- a/nats/tests/test_js.py +++ b/nats/tests/test_js.py @@ -3100,7 +3100,7 @@ async def error_handler(e): assert config.template_owner == None version = nc.connected_server_version - if version.major == 2 and version.minor < 9: + if version.major == 2 and (version.minor < 9 or version.minor > 12): assert config.allow_direct == None else: assert config.allow_direct == False @@ -4040,14 +4040,21 @@ async def error_handler(e): assert sinfo.config.max_msgs == -1 assert sinfo.config.max_bytes == -1 assert sinfo.config.discard == "new" - assert sinfo.config.max_age == 0 + version = nc.connected_server_version + if version.major == 2 and version.minor > 12: + assert sinfo.config.max_age is None + else: + assert sinfo.config.max_age == 0 assert sinfo.config.max_msgs_per_subject == -1 assert sinfo.config.max_msg_size == -1 assert sinfo.config.storage == "file" assert sinfo.config.num_replicas == 1 assert sinfo.config.allow_rollup_hdrs == True assert sinfo.config.allow_direct == True - assert sinfo.config.mirror_direct == False + if version.major == 2 and version.minor > 12: + assert sinfo.config.mirror_direct is None + else: + assert sinfo.config.mirror_direct == False bucketname = "".join(random.SystemRandom().choice(string.ascii_letters) for _ in range(10)) obs = await js.create_object_store(bucket=bucketname) @@ -4844,7 +4851,11 @@ async def test_stream_compression(self): compression="none", ) sinfo = await js.stream_info("NONE") - assert sinfo.config.compression == nats.js.api.StoreCompression.NONE + version = nc.connected_server_version + if version.major == 2 and version.minor > 12: + assert sinfo.config.compression is None + else: + assert 
sinfo.config.compression == nats.js.api.StoreCompression.NONE # By default it should be using 'none' as the configured compression value. js = nc.jetstream() @@ -4853,7 +4864,10 @@ async def test_stream_compression(self): subjects=["quux"], ) sinfo = await js.stream_info("NONE2") - assert sinfo.config.compression == nats.js.api.StoreCompression.NONE + if version.major == 2 and version.minor > 12: + assert sinfo.config.compression is None + else: + assert sinfo.config.compression == nats.js.api.StoreCompression.NONE await nc.close() @async_test @@ -5146,3 +5160,344 @@ async def test_add_stream_invalid_names(self): ), ): await js.add_stream(name=name) + + +class PriorityGroupsFeaturesTest(SingleJetStreamServerTestCase): + @async_test + async def test_consumer_overflow(self): + nc = await nats.connect() + + server_version = nc.connected_server_version + if server_version.major == 2 and server_version.minor < 12: + pytest.skip("consumer group overflow requires nats-server v2.11.0 or later") + + js = nc.jetstream() + + # create stream + await js.add_stream( + name="PRIORITIES", + subjects=["foo"], + ) + + # create consumer with overflow priority policy + cinfo = await js.add_consumer( + "PRIORITIES", + nats.js.api.ConsumerConfig( + priority_policy=nats.js.api.PriorityPolicy.OVERFLOW, + priority_groups=["A"], + ), + ) + assert cinfo.config.priority_policy == nats.js.api.PriorityPolicy.OVERFLOW + + # 1. Below threshold - no messages delivered + # - publish 100 msgs + # - fetch 10 msgs with min_pending 110 + # - should not get any msgs since 100<110 + psub = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + for i in range(0, 100): + await js.publish("foo", f"{i}".encode()) + with pytest.raises(TimeoutError): + msgs = await psub.fetch(10, timeout=0.5, min_pending=110) + await psub.unsubscribe() + + # 2. 
Above threshold - messages delivered + # - publish 100 more msgs + # - fetch 10 msgs with min_pending 110 + # - should get 10 msgs since (200-10)>110 + psub = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + for i in range(0, 100): + await js.publish("foo", f"{i}".encode()) + msgs = await psub.fetch(10, timeout=0.5, min_pending=110) + assert len(msgs) == 10 + for msg in msgs: # clean up + await msg.ack_sync() + await psub.unsubscribe() + + # 3: MinAckPending - no unacked messages yet + # - fetch 10 msgs with min_ack_pending 10 + # - should get 0 msgs since no pending acks currently + psub = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + with pytest.raises(TimeoutError): + msgs = await psub.fetch(10, timeout=0.5, min_ack_pending=10) + await psub.unsubscribe() + + # 4: MinAckPending threshold met + # - create 10 pending acks + # - fetch 10 msgs with min_ack_pending 10 + # - should get 10 msgs since 10 pending acks >=10 + # NOTE: the psub's buffer queue can get filled with extra messages which + # leak into subsequent fetch calls, so to check unbuffered behavior we + # use separate subs + psub1 = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + psub2 = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + unacked_msgs = await psub1.fetch(10, timeout=0.5) + msgs = await psub2.fetch(10, timeout=0.5, min_ack_pending=10) + assert len(msgs) == 10 + for msg in unacked_msgs + msgs: # clean up + await msg.ack_sync() + + await nc.close() + + @async_test + async def test_consumer_pinned(self): + nc = await nats.connect() + + server_version = nc.connected_server_version + if server_version.major == 2 and server_version.minor < 12: + pytest.skip("consumer group pinning requires nats-server v2.11.0 or later") + + js = nc.jetstream() + + # create stream + await js.add_stream( + name="PRIORITIES", + 
subjects=["foo"], + ) + + # create consumer with pinned priority policy + cinfo = await js.add_consumer( + "PRIORITIES", + nats.js.api.ConsumerConfig( + priority_policy=nats.js.api.PriorityPolicy.PINNED, + priority_timeout=1.0, + priority_groups=["A"], + ), + ) + assert cinfo.config.priority_policy == nats.js.api.PriorityPolicy.PINNED + + # publish messages + for i in range(100): + await js.publish("foo", f"{i}".encode()) + + # 1. Priority group validation - invalid group + psub = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="BAD", + ) + with pytest.raises(nats.js.errors.APIError, match="Invalid Priority Group"): + await psub.fetch(10, timeout=0.5) + await psub.unsubscribe() + + # 2. Priority group validation - no group + psub = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + ) + with pytest.raises(nats.js.errors.APIError, match="Priority Group missing"): + await psub.fetch(10, timeout=0.5) + await psub.unsubscribe() + + # 3. First consumer gets pinned + psub1 = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + msgs = await psub1.fetch(10, timeout=0.5) + assert len(msgs) == 10 + first_pin_id = msgs[0].headers.get("Nats-Pin-Id") if msgs[0].headers else None + assert first_pin_id is not None + # all messages should have same pin id + for msg in msgs: + assert msg.headers.get("Nats-Pin-Id") == first_pin_id + await msg.ack_sync() + + # 4. Different consumer instance can't fetch while pinned + psub2 = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + with pytest.raises(TimeoutError): + await psub2.fetch(10, timeout=0.5) + + # 5. Original consumer continues to work + msgs = await psub1.fetch(10, timeout=0.5) + assert len(msgs) == 10 + for msg in msgs: + assert msg.headers.get("Nats-Pin-Id") == first_pin_id + await msg.ack_sync() + + # 6. 
After TTL expires, pin ID changes + await asyncio.sleep(1.5) # longer than priority_timeout (1s) + msgs = await psub1.fetch(10, timeout=0.5) + assert len(msgs) == 10 + new_pin_id = msgs[0].headers.get("Nats-Pin-Id") if msgs[0].headers else None + assert new_pin_id is not None + assert new_pin_id != first_pin_id + for msg in msgs: # clean up + await msg.ack_sync() + + await psub1.unsubscribe() + await psub2.unsubscribe() + await nc.close() + + @async_test + async def test_consumer_unpin(self): + nc = await nats.connect() + + server_version = nc.connected_server_version + if server_version.major == 2 and server_version.minor < 12: + pytest.skip("consumer group unpinning requires nats-server v2.11.0 or later") + + js = nc.jetstream() + jsm = js._jsm + + # create stream + await js.add_stream( + name="PRIORITIES", + subjects=["foo"], + ) + + # create consumer with pinned priority policy and long TTL + cinfo = await js.add_consumer( + "PRIORITIES", + nats.js.api.ConsumerConfig( + priority_policy=nats.js.api.PriorityPolicy.PINNED, + priority_timeout=50.0, + priority_groups=["A"], + ), + ) + assert cinfo.config.priority_policy == nats.js.api.PriorityPolicy.PINNED + + # publish messages + for i in range(100): + await js.publish("foo", f"{i}".encode()) + + # 1. First consumer gets pinned + psub1 = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + msgs = await psub1.fetch(1, timeout=0.5) + assert len(msgs) == 1 + first_pin_id = msgs[0].headers.get("Nats-Pin-Id") if msgs[0].headers else None + assert first_pin_id is not None + await msgs[0].ack_sync() + + # 2. Second consumer can't get messages while first is pinned + psub2 = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + with pytest.raises(TimeoutError): + await psub2.fetch(1, timeout=0.5) + await psub2.unsubscribe() + + # 3. 
Manual unpin allows third consumer + psub3 = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + + # Unpin the consumer + await jsm.unpin_consumer(cinfo.stream_name, cinfo.name, "A") + + # Third consumer should now receive message with new pin ID + msgs = await psub3.fetch(1, timeout=0.5) + assert len(msgs) == 1 + new_pin_id = msgs[0].headers.get("Nats-Pin-Id") if msgs[0].headers else None + assert new_pin_id is not None + assert new_pin_id != first_pin_id + + await psub1.unsubscribe() + await psub3.unsubscribe() + + # 4. Test unpin on non-existent consumer + with pytest.raises(nats.js.errors.NotFoundError): + await jsm.unpin_consumer("PRIORITIES", "nonexistent", "A") + + await nc.close() + + @async_test + async def test_consumer_prioritized(self): + nc = await nats.connect() + + server_version = nc.connected_server_version + if server_version.major == 2 and server_version.minor < 12: + pytest.skip("consumer group priority requires nats-server v2.12.0 or later") + + js = nc.jetstream() + + # create stream + await js.add_stream( + name="PRIORITIES", + subjects=["foo"], + ) + + # create consumer with prioritized priority policy + cinfo = await js.add_consumer( + "PRIORITIES", + nats.js.api.ConsumerConfig( + priority_policy=nats.js.api.PriorityPolicy.PRIORITIZED, + priority_groups=["A"], + ), + ) + assert cinfo.config.priority_policy == nats.js.api.PriorityPolicy.PRIORITIZED + + # Test: Messages distributed based on priority + # Higher priority (lower number) consumers get messages first + + # Create two consumer instances + psub1 = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + psub2 = await js.pull_subscribe_bind( + cinfo.name, + cinfo.stream_name, + priority_group="A", + ) + + # publish 100 messages + for i in range(100): + await js.publish("foo", f"{i}".encode()) + + # Start concurrent fetches: + # psub1 with priority=1 (lower priority) requesting 100 messages + # psub2 with 
priority=0 (higher priority) requesting 75 messages + # Expected: psub2 gets 75 first, psub1 gets remaining 25 + + fetch1_task = asyncio.create_task(psub1.fetch(100, timeout=2.0, priority=1)) + fetch2_task = asyncio.create_task(psub2.fetch(75, timeout=2.0, priority=0)) + + # Wait for both fetches + msgs1, msgs2 = await asyncio.gather(fetch1_task, fetch2_task) + + # psub2 (priority 0) should get 75 messages + assert len(msgs2) == 75 + + # psub1 (priority 1) should get remaining 25 messages + assert len(msgs1) == 25 + + for msg in msgs1 + msgs2: # clean up + await msg.ack_sync() + + await psub1.unsubscribe() + await psub2.unsubscribe() + await nc.close()