@@ -347,17 +347,7 @@ def schedule(self) -> SchedulerOutput:
347347 else :
348348 preempted_req = self .running .pop ()
349349
350- self .kv_cache_manager .free (preempted_req )
351- self .encoder_cache_manager .free (preempted_req )
352- preempted_req .status = RequestStatus .PREEMPTED
353- preempted_req .num_computed_tokens = 0
354- preempted_req .num_preemptions += 1
355- if self .log_stats :
356- preempted_req .record_event (
357- EngineCoreEventType .PREEMPTED , scheduled_timestamp
358- )
359-
360- self .waiting .prepend_request (preempted_req )
350+ self ._preempt_request (preempted_req , scheduled_timestamp )
361351 preempted_reqs .append (preempted_req )
362352 if preempted_req == request :
363353 # No more request to preempt. Cannot schedule this request.
@@ -756,6 +746,30 @@ def schedule(self) -> SchedulerOutput:
756746 self ._update_after_schedule (scheduler_output )
757747 return scheduler_output
758748
def _preempt_request(
    self,
    request: Request,
    timestamp: float,
) -> None:
    """Move a running request back into the waiting queue.

    Frees the request's KV-cache and encoder-cache allocations, marks it
    PREEMPTED, zeroes its computed-token count, and bumps its preemption
    counter before placing it at the *front* of the waiting queue.

    NOTE: The caller is responsible for popping the request off the
    running queue before invoking this method.
    """
    assert request.status == RequestStatus.RUNNING, (
        "Only running requests can be preempted"
    )
    # Release the cache resources the request currently holds.
    self.kv_cache_manager.free(request)
    self.encoder_cache_manager.free(request)

    request.status = RequestStatus.PREEMPTED
    request.num_computed_tokens = 0
    request.num_preemptions += 1
    if self.log_stats:
        request.record_event(EngineCoreEventType.PREEMPTED, timestamp)

    # Re-queue at the head so the request resumes ahead of newer arrivals.
    self.waiting.prepend_request(request)
759773 def _update_after_schedule (
760774 self ,
761775 scheduler_output : SchedulerOutput ,
@@ -1362,8 +1376,45 @@ def get_num_unfinished_requests(self) -> int:
def has_finished_requests(self) -> bool:
    """Return True if at least one finished request id has been recorded."""
    return bool(self.finished_req_ids)
13641378
1365- def reset_prefix_cache (self ) -> bool :
1366- return self .kv_cache_manager .reset_prefix_cache ()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
    """Reset the KV prefix cache.

    If reset_running_requests is True, every running request is first
    preempted and pushed back onto the waiting queue, which drops all KV
    block ref counts to 0 and guarantees the reset can succeed.
    Otherwise the reset only succeeds when no running request is holding
    KV cache.
    """
    if not reset_running_requests:
        return self.kv_cache_manager.reset_prefix_cache()

    # Timestamp used when logging the preemption events.
    preemption_time = time.monotonic()
    # Invalidate the KV blocks of every running request by forcing them
    # back onto the waiting queue. Popping from the tail means the
    # requests are prepended in reverse, so they re-enter the running
    # queue in FIFO order later.
    while self.running:
        victim = self.running.pop()
        self._preempt_request(victim, preemption_time)
        # NOTE(zhuohan): For async scheduling, we need to discard the latest
        # output token on the fly to avoid a redundant repetitive output token.
        victim.num_output_placeholders = 0
        victim.discard_latest_async_tokens = True

    # Clear scheduled request ids cache. Since we are forcing preemption
    # + resumption in the same step, we must act as if these requests were
    # not scheduled in the prior step. They will be flushed from the
    # persistent batch in the model runner.
    self.prev_step_scheduled_req_ids.clear()

    if not self.kv_cache_manager.reset_prefix_cache():
        raise RuntimeError(
            "Failed to reset KV cache even when all the running requests are "
            "preempted and moved to the waiting queue. This is likely due to "
            "the presence of running requests waiting for remote KV transfer, "
            "which is not supported yet."
        )
    return True
13671418
13681419 def make_stats (
13691420 self ,
0 commit comments