 finding the first match across all sequences in parallel.
 """
 
-import numpy as np
 import torch
 from torch import nn
 
     VllmConfig,
 )
 from vllm.forward_context import set_forward_context
-from vllm.utils.platform_utils import is_pin_memory_available
-from vllm.v1.attention.backends.utils import (
-    CommonAttentionMetadata,
-)
-from vllm.v1.utils import CpuGpuBuffer
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_input_batch import InputBatch
 
 
 @support_torch_compile(
@@ -295,17 +289,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
         self.device = device
         self.kernel.to(device)
         self.kernel.eval()
-        max_batch_size = vllm_config.scheduler_config.max_num_seqs
-
-        # TODO(patchy): Remove this buffer, use
-        # token_ids_gpu_tensor in gpu_model_runner.py instead.
-        self.backup_next_token_ids = CpuGpuBuffer(
-            max_batch_size,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available(),
-            device=device,
-            with_numpy=True,
-        )
 
         self._dummy_run()
 
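For context on the hunk above: the deleted backup_next_token_ids buffer implemented a stage-on-host-then-copy pattern. Below is a minimal sketch of that general pattern with hypothetical names and sizes; it is not vLLM's actual CpuGpuBuffer implementation.

import torch

# Simplified stand-in for the removed CpuGpuBuffer: an optionally pinned host
# tensor plus a device tensor; the host side is filled in Python, then copied over.
max_batch_size = 4
use_cuda = torch.cuda.is_available()
cpu_buf = torch.empty(max_batch_size, dtype=torch.int32, pin_memory=use_cuda)
gpu_buf = torch.empty(
    max_batch_size, dtype=torch.int32, device="cuda" if use_cuda else "cpu"
)

def copy_to_gpu(n: int) -> torch.Tensor:
    # The explicit host-to-device copy that the gather-based path below avoids.
    gpu_buf[:n].copy_(cpu_buf[:n], non_blocking=True)
    return gpu_buf[:n]

The later hunks read the backup tokens directly from token_ids_gpu, so neither the host staging array nor this copy is needed.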
@@ -400,14 +383,14 @@ def propose(
 
         return draft_tokens, is_empty_draft_tokens
 
-    def prepare_next_token_ids_padded(
+    def update_token_ids_ngram(
         self,
-        common_attn_metadata: CommonAttentionMetadata,
         sampled_token_ids: torch.Tensor,
-        requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
         discard_request_indices: torch.Tensor,
         num_discarded_requests: int,
+        token_ids_gpu: torch.Tensor,
+        num_tokens_no_spec: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         This function is used to prepare the inputs for speculative decoding.
@@ -419,17 +402,14 @@ def prepare_next_token_ids_padded(
         should not introduce any blocking CPU-GPU synchronization.
         """
         num_reqs = gpu_input_batch.num_reqs
-        # Batch convert seq_lens to avoid multiple .item() calls
-        seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
-
-        # Now use the pre-converted list to avoid .item() calls in the loop
-        self.backup_next_token_ids.np[:num_reqs] = np.array(
-            [
-                requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
-                for i in range(num_reqs)
-            ]
-        )
-        self.backup_next_token_ids.copy_to_gpu(num_reqs)
+
+        # Extract backup_next_token_ids from token_ids_gpu using vectorized gather.
+        # For each request i, get token_ids_gpu[i, num_tokens_no_spec[i] - 1],
+        # the last valid token before speculative tokens.
+        backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
+        backup_next_token_ids = torch.gather(
+            token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1)
+        ).squeeze(1)
 
         # Mask out the sampled tokens indices that should not be sampled.
         discard_sampled_tokens_req_indices = discard_request_indices[
@@ -459,12 +439,11 @@ def prepare_next_token_ids_padded(
             valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
         ).squeeze(1)
 
-        # Use last token if valid, pre-computed backup if not
-        batch_size = valid_sampled_token_ids_gpu.shape[0]
+        # Use last token if valid, vectorized backup from token_ids_gpu if not
         next_token_ids = torch.where(
             last_valid_indices != -1,
             selected_tokens,
-            self.backup_next_token_ids.gpu[:batch_size],
+            backup_next_token_ids,
         )
 
         return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu
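Taken together, the change replaces a per-request Python loop over host data with a single gather over tensors already resident on the GPU, followed by a torch.where fallback. A minimal, self-contained sketch of that pattern with toy tensors (shapes and values are illustrative only, not vLLM's real inputs):

import torch

num_reqs = 3
max_model_len = 8
# token_ids_gpu[i] holds the token history of request i (padded to max_model_len).
token_ids_gpu = torch.arange(num_reqs * max_model_len).reshape(num_reqs, max_model_len)
# num_tokens_no_spec[i] = number of tokens of request i, excluding speculative ones.
num_tokens_no_spec = torch.tensor([5, 1, 8])

# Backup token = last valid (non-speculative) token per request, gathered in one op.
backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
backup_next_token_ids = torch.gather(
    token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1)
).squeeze(1)  # tensor([ 4,  8, 23])

# Requests whose sampled token was rejected (last_valid_indices == -1) fall back
# to the backup token; everything stays on the device, so no sync is triggered.
selected_tokens = torch.tensor([100, 200, 300])
last_valid_indices = torch.tensor([2, -1, 0])
next_token_ids = torch.where(
    last_valid_indices != -1, selected_tokens, backup_next_token_ids
)
print(next_token_ids)  # tensor([100,   8, 300])

Because the gather and the torch.where both operate only on device tensors, no .item() calls or host copies are issued, which matches the "should not introduce any blocking CPU-GPU synchronization" note in the docstring.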