diff --git a/tests/v1/core/test_priority_scheduler_random.py b/tests/v1/core/test_priority_scheduler_random.py index b4805be80272..429b179b61dc 100644 --- a/tests/v1/core/test_priority_scheduler_random.py +++ b/tests/v1/core/test_priority_scheduler_random.py @@ -219,7 +219,17 @@ def test_priority_scheduling_blast( vllm_config=scheduler.vllm_config, ) scheduler.add_request(req) - + num_initial_requests = 2 + for _ in range(num_initial_requests): + req = _create_random_request( + max_tokens_range=(1, max_output_tokens), + num_tokens_range=(1, max_input_tokens), + arrival_time_range=(0, 0), + priority_range=(4, 4), + num_mm_item_range=(0, 2), + vllm_config=scheduler.vllm_config, + ) + scheduler.add_request(req) for _ in range(20000): if len(scheduler.waiting) == 0: num_new_requests = random.randint(0, 2) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 7bc1010db23a..a00ca1912b0f 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -137,31 +137,30 @@ class PriorityRequestQueue(RequestQueue): """ A priority queue that supports heap operations. - Requests with a smaller value of `priority` are processed first. + Respects the ordering defined in the Request class, where + requests with a smaller value of `priority` are processed first. If multiple requests have the same priority, the one with the earlier `arrival_time` is processed first. """ def __init__(self) -> None: - self._heap: list[tuple[int, float, Request]] = [] + self._heap: list[Request] = [] def add_request(self, request: Request) -> None: """Add a request to the queue according to priority policy.""" - heapq.heappush(self._heap, (request.priority, request.arrival_time, request)) + heapq.heappush(self._heap, request) def pop_request(self) -> Request: """Pop a request from the queue according to priority policy.""" if not self._heap: raise IndexError("pop from empty heap") - _, _, request = heapq.heappop(self._heap) - return request + return heapq.heappop(self._heap) def peek_request(self) -> Request: """Peek at the next request in the queue without removing it.""" if not self._heap: raise IndexError("peek from empty heap") - _, _, request = self._heap[0] - return request + return self._heap[0] def prepend_request(self, request: Request) -> None: """Add a request to the queue according to priority policy. @@ -180,15 +179,13 @@ def prepend_requests(self, requests: RequestQueue) -> None: def remove_request(self, request: Request) -> None: """Remove a specific request from the queue.""" - self._heap = [(p, t, r) for p, t, r in self._heap if r != request] + self._heap.remove(request) heapq.heapify(self._heap) def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the queue.""" - requests_to_remove = set(requests) - self._heap = [ - (p, t, r) for p, t, r in self._heap if r not in requests_to_remove - ] + requests_to_remove = requests if isinstance(requests, set) else set(requests) + self._heap = [r for r in self._heap if r not in requests_to_remove] heapq.heapify(self._heap) def __bool__(self) -> bool: @@ -203,8 +200,7 @@ def __iter__(self) -> Iterator[Request]: """Iterate over the queue according to priority policy.""" heap_copy = self._heap[:] while heap_copy: - _, _, request = heapq.heappop(heap_copy) - yield request + yield heapq.heappop(heap_copy) def __reversed__(self) -> Iterator[Request]: """Iterate over the queue in reverse priority order.""" diff --git a/vllm/v1/request.py b/vllm/v1/request.py index f2dfd2eed03c..33762fe34e64 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -227,6 +227,19 @@ def take_events(self) -> list[EngineCoreEvent] | None: events, self.events = self.events, [] return events + def __lt__(self, other: "Request") -> bool: + """ + Compare two requests based on priority, arrival time, and request ID. + Used in priority scheduling. + """ + if self.priority != other.priority: + return self.priority < other.priority + if self.arrival_time != other.arrival_time: + return self.arrival_time < other.arrival_time + if self.request_id != other.request_id: + return self.request_id < other.request_id + return id(self) < id(other) + class RequestStatus(enum.IntEnum): """Status of a request."""