
Commit d0cd728

zhuohan123 and njhill authored
[Core] Support resetting all running requests' KV while calling reset_prefix_cache (#28827)
Signed-off-by: Zhuohan Li <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent fa8804a commit d0cd728

File tree

16 files changed: +315 -35 lines changed
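In short, this commit adds a reset_running_requests flag to reset_prefix_cache across the engine, scheduler, and serving layers: when the flag is set, every running request is force-preempted back to the waiting queue so all KV block reference counts drop to zero and the prefix cache can always be cleared (for example after a live weight update), and the preempted requests are then rescheduled and should produce the same outputs. A minimal usage sketch of the new call path, assuming the diffs below are applied (the model name and step counts are placeholders):

from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="Qwen/Qwen3-0.6B"))
engine.add_request("req-0", "To be or not to be,", SamplingParams(max_tokens=64))

# Run a few steps so the request is actually holding KV blocks.
for _ in range(5):
    engine.step()

# Force-preempt all running requests and clear the prefix cache in one call.
# Returns True on success; preempted requests resume from the waiting queue.
assert engine.reset_prefix_cache(reset_running_requests=True)

while engine.has_unfinished_requests():
    engine.step()
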
Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates preempting requests when using the `LLMEngine`
for processing prompts with various sampling parameters.
"""

import argparse

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser


def create_test_prompts() -> list[tuple[str, SamplingParams]]:
    """Create a list of test prompts with their sampling parameters."""
    return [
        (
            "A robot may not injure a human being " * 50,
            SamplingParams(
                temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
            ),
        ),
        (
            "A robot may not injure a human being " * 50,
            SamplingParams(
                temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
            ),
        ),
        (
            "To be or not to be,",
            SamplingParams(
                temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
            ),
        ),
        (
            "What is the meaning of life?",
            SamplingParams(
                n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1, max_tokens=128
            ),
        ),
    ]


def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    print("-" * 50)
    step_id = 0
    while test_prompts or engine.has_unfinished_requests():
        print("-" * 50)
        import os

        print(f"Step {step_id} (pid={os.getpid()})")

        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        if step_id == 10:
            print(f"Resetting prefix cache at {step_id}")
            engine.reset_prefix_cache(reset_running_requests=True)

        request_outputs: list[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print("-" * 50)
                print(request_output)
                print("-" * 50)
        step_id += 1


def initialize_engine(args: argparse.Namespace) -> LLMEngine:
    """Initialize the LLMEngine from the command line arguments."""
    engine_args = EngineArgs.from_cli_args(args)
    return LLMEngine.from_engine_args(engine_args)


def parse_args():
    parser = FlexibleArgumentParser(
        description="Demo on using the LLMEngine class directly"
    )
    parser = EngineArgs.add_cli_args(parser)
    return parser.parse_args()


def main(args: argparse.Namespace):
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine(args)
    test_prompts = create_test_prompts()
    process_requests(engine, test_prompts)


if __name__ == "__main__":
    args = parse_args()
    main(args)
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import EngineArgs, LLMEngine, SamplingParams

PROMPTS = [
    "A robot may not injure a human being ",
    "To be or not to be,",
    "What is the meaning of life?",
    "What does the fox say? " * 20,  # Test long prompt
]


def test_reset_prefix_cache_e2e():
    engine_args = EngineArgs(
        model="Qwen/Qwen3-0.6B",
        gpu_memory_utilization=0.2,
        async_scheduling=True,
        max_num_batched_tokens=32,
        max_model_len=2048,
        compilation_config={"mode": 0},
    )
    engine = LLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=16,
    )

    # No preempt case:
    for i, prompt in enumerate(PROMPTS):
        engine.add_request("ground_truth_" + str(i), prompt, sampling_params)

    ground_truth_results = {}
    while engine.has_unfinished_requests():
        request_outputs = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                ground_truth_results[request_output.request_id] = request_output

    # Preempt case:
    for i, prompt in enumerate(PROMPTS):
        engine.add_request("preempted_" + str(i), prompt, sampling_params)

    step_id = 0
    preempted_results = {}
    while engine.has_unfinished_requests():
        if step_id == 10:
            engine.reset_prefix_cache(reset_running_requests=True)

        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                preempted_results[request_output.request_id] = request_output
        step_id += 1

    for i in range(len(PROMPTS)):
        assert (
            ground_truth_results["ground_truth_" + str(i)].outputs[0].text
            == preempted_results["preempted_" + str(i)].outputs[0].text
        ), (
            f"ground_truth_results['ground_truth_{i}'].outputs[0].text="
            f"{ground_truth_results['ground_truth_' + str(i)].outputs[0].text} "
            f"preempted_results['preempted_{i}'].outputs[0].text="
            f"{preempted_results['preempted_' + str(i)].outputs[0].text}"
        )

tests/v1/core/test_scheduler.py

Lines changed: 31 additions & 0 deletions
@@ -728,6 +728,37 @@ def test_preempt_during_execution():
     assert requests[1].output_token_ids[0] == 42
 
 
+def test_scheduler_reset_prefix_cache():
+    scheduler = create_scheduler(enable_prefix_caching=True)
+    requests = create_requests(num_requests=10)
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Initial scheduling, requests should be at the running state now
+    _ = scheduler.schedule()
+
+    # Verify requests moved from waiting to running
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == len(requests)
+    for i, request in enumerate(requests):
+        assert scheduler.running[i] == request
+
+    # Reset prefix cache should fail since there are still running requests
+    # and they are taking KV cache
+    assert not scheduler.reset_prefix_cache()
+
+    # Reset prefix cache with reset_running_requests=True. All running requests
+    # should be pushed back to the waiting queue and kv cache should be freed
+    assert scheduler.reset_prefix_cache(reset_running_requests=True)
+
+    # Verify requests moved from running to waiting
+    assert len(scheduler.waiting) == len(requests)
+    assert len(scheduler.running) == 0
+    for i, request in enumerate(requests):
+        assert scheduler.waiting[i] == request
+
+
 # Note - these test cases mirror some of those in test_rejection_sampler.py
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",

vllm/engine/protocol.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ async def reset_mm_cache(self) -> None:
         ...
 
     @abstractmethod
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
         """Reset the prefix cache"""
         ...
 
vllm/entrypoints/llm.py

Lines changed: 2 additions & 2 deletions
@@ -1492,8 +1492,8 @@ def start_profile(self) -> None:
     def stop_profile(self) -> None:
         self.llm_engine.stop_profile()
 
-    def reset_prefix_cache(self) -> None:
-        self.llm_engine.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.llm_engine.reset_prefix_cache(reset_running_requests)
 
     def sleep(self, level: int = 1):
         """

vllm/entrypoints/openai/api_server.py

Lines changed: 4 additions & 2 deletions
@@ -877,13 +877,15 @@ async def show_server_info(
     return JSONResponse(content=server_info)
 
 @router.post("/reset_prefix_cache")
-async def reset_prefix_cache(raw_request: Request):
+async def reset_prefix_cache(
+    raw_request: Request, reset_running_requests: bool = Query(default=False)
+):
     """
     Reset the prefix cache. Note that we currently do not check if the
     prefix cache is successfully reset in the API server.
     """
     logger.info("Resetting prefix cache...")
-    await engine_client(raw_request).reset_prefix_cache()
+    await engine_client(raw_request).reset_prefix_cache(reset_running_requests)
     return Response(status_code=200)
 
 @router.post("/reset_mm_cache")
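
For the OpenAI-compatible server, the flag is exposed as a query parameter on the existing endpoint. A hedged client-side sketch, assuming a server is already running on the default localhost:8000 and the third-party requests package is installed:

import requests

# POST /reset_prefix_cache?reset_running_requests=true
# force-preempts running requests server-side before clearing the prefix cache.
resp = requests.post(
    "http://localhost:8000/reset_prefix_cache",
    params={"reset_running_requests": "true"},
)
assert resp.status_code == 200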

vllm/v1/core/sched/async_scheduler.py

Lines changed: 6 additions & 0 deletions
@@ -45,6 +45,12 @@ def _update_request_with_output(
         request: Request,
         new_token_ids: list[int],
     ) -> tuple[list[int], bool]:
+        if request.discard_latest_async_tokens:
+            # If the request is force preempted in reset_prefix_cache, we
+            # should discard the latest async token.
+            request.discard_latest_async_tokens = False
+            return [], False
+
         status_before_update = request.status
         new_token_ids, stopped = super()._update_request_with_output(
             request, new_token_ids

vllm/v1/core/sched/interface.py

Lines changed: 7 additions & 1 deletion
@@ -152,10 +152,16 @@ def has_requests(self) -> bool:
         return self.has_unfinished_requests() or self.has_finished_requests()
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
         """Reset the prefix cache for KV cache.
 
         This is particularly required when the model weights are live-updated.
+
+        Args:
+            reset_running_requests: If True, all the running requests will be
+                preempted and moved to the waiting queue. Otherwise, this method
+                will only reset the KV prefix cache when there is no running request
+                taking KV cache.
         """
         raise NotImplementedError
 

vllm/v1/core/sched/scheduler.py

Lines changed: 64 additions & 13 deletions
@@ -347,17 +347,7 @@ def schedule(self) -> SchedulerOutput:
                     else:
                         preempted_req = self.running.pop()
 
-                    self.kv_cache_manager.free(preempted_req)
-                    self.encoder_cache_manager.free(preempted_req)
-                    preempted_req.status = RequestStatus.PREEMPTED
-                    preempted_req.num_computed_tokens = 0
-                    preempted_req.num_preemptions += 1
-                    if self.log_stats:
-                        preempted_req.record_event(
-                            EngineCoreEventType.PREEMPTED, scheduled_timestamp
-                        )
-
-                    self.waiting.prepend_request(preempted_req)
+                    self._preempt_request(preempted_req, scheduled_timestamp)
                     preempted_reqs.append(preempted_req)
                     if preempted_req == request:
                         # No more request to preempt. Cannot schedule this request.
@@ -756,6 +746,30 @@ def schedule(self) -> SchedulerOutput:
         self._update_after_schedule(scheduler_output)
         return scheduler_output
 
+    def _preempt_request(
+        self,
+        request: Request,
+        timestamp: float,
+    ) -> None:
+        """Preempt a request and put it back to the waiting queue.
+
+        NOTE: The request should be popped from the running queue outside of this
+        method.
+        """
+        assert request.status == RequestStatus.RUNNING, (
+            "Only running requests can be preempted"
+        )
+        self.kv_cache_manager.free(request)
+        self.encoder_cache_manager.free(request)
+        request.status = RequestStatus.PREEMPTED
+        request.num_computed_tokens = 0
+        request.num_preemptions += 1
+        if self.log_stats:
+            request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
+
+        # Put the request back to the waiting queue.
+        self.waiting.prepend_request(request)
+
     def _update_after_schedule(
         self,
         scheduler_output: SchedulerOutput,
@@ -1362,8 +1376,45 @@ def get_num_unfinished_requests(self) -> int:
     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0
 
-    def reset_prefix_cache(self) -> bool:
-        return self.kv_cache_manager.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        """Reset the KV prefix cache.
+
+        If reset_running_requests is True, all the running requests will be
+        preempted and moved to the waiting queue.
+        Otherwise, this method will only reset the KV prefix cache when there
+        is no running requests taking KV cache.
+        """
+        if reset_running_requests:
+            # For logging.
+            timestamp = time.monotonic()
+            # Invalidate all the current running requests KV's by pushing them to
+            # the waiting queue. In this case, we can reduce the ref count of all
+            # the kv blocks to 0 and thus we can make sure the reset is successful.
+            # Preempt in reverse order so the requests will be added back to the
+            # running queue in FIFO order.
+            while self.running:
+                request = self.running.pop()
+                self._preempt_request(request, timestamp)
+                # NOTE(zhuohan): For async scheduling, we need to discard the latest
+                # output token on the fly to avoid a redundant repetitive output token.
+                request.num_output_placeholders = 0
+                request.discard_latest_async_tokens = True
+
+            # Clear scheduled request ids cache. Since we are forcing preemption
+            # + resumption in the same step, we must act as if these requests were
+            # not scheduled in the prior step. They will be flushed from the
+            # persistent batch in the model runner.
+            self.prev_step_scheduled_req_ids.clear()
+
+        reset_successful = self.kv_cache_manager.reset_prefix_cache()
+        if reset_running_requests and not reset_successful:
+            raise RuntimeError(
+                "Failed to reset KV cache even when all the running requests are "
+                "preempted and moved to the waiting queue. This is likely due to "
+                "the presence of running requests waiting for remote KV transfer, "
+                "which is not supported yet."
+            )
+        return reset_successful
 
     def make_stats(
         self,

vllm/v1/engine/async_llm.py

Lines changed: 2 additions & 2 deletions
@@ -750,8 +750,8 @@ async def reset_mm_cache(self) -> None:
         self.input_processor.clear_mm_cache()
         await self.engine_core.reset_mm_cache_async()
 
-    async def reset_prefix_cache(self) -> None:
-        await self.engine_core.reset_prefix_cache_async()
+    async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return await self.engine_core.reset_prefix_cache_async(reset_running_requests)
 
     async def sleep(self, level: int = 1) -> None:
         await self.reset_prefix_cache()
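
On the async serving path, AsyncLLM.reset_prefix_cache now returns the result of the core RPC. A minimal sketch, under the assumption that AsyncLLM.from_engine_args and AsyncEngineArgs are available at these import paths (they may differ across vLLM versions) and that the call is made while requests are in flight:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def clear_kv_cache() -> None:
    engine = AsyncLLM.from_engine_args(AsyncEngineArgs(model="Qwen/Qwen3-0.6B"))
    # Force-preempt any in-flight requests and clear the prefix cache;
    # they are rescheduled from the waiting queue afterwards.
    ok = await engine.reset_prefix_cache(reset_running_requests=True)
    print(f"prefix cache reset: {ok}")


asyncio.run(clear_kv_cache())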
