 @support_torch_compile(
     dynamic_arg_dims={
         "num_tokens_no_spec": 0,
-        "token_ids_gpu": [0, 1],
+        "token_ids_gpu": 0,
         "combined_mask": 0,
     }
 )
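As context for the hunk above: `dynamic_arg_dims` tells vLLM's `support_torch_compile` decorator which dimensions of the listed `forward()` arguments may change between calls, so switching `token_ids_gpu` from `[0, 1]` to `0` treats its second dimension as static. A rough standalone sketch of the same idea using plain `torch.compile` and `torch._dynamo.mark_dynamic`; the `count_drafts` toy function and shapes are invented for illustration:

```python
import torch


def count_drafts(token_ids_gpu: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the compiled kernel; counts nonzero tokens per row.
    return (token_ids_gpu != 0).sum(dim=1)


compiled = torch.compile(count_drafts)

for batch_size in (2, 4, 8):
    x = torch.randint(0, 100, (batch_size, 16), dtype=torch.int32)
    # Declare only dim 0 (the batch dimension) dynamic, mirroring
    # "token_ids_gpu": 0 above; dim 1 stays a static shape.
    torch._dynamo.mark_dynamic(x, 0)
    compiled(x)
```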
@@ -196,7 +196,7 @@ def _find_first_and_extract_all_n_parallel(
         results = torch.where(
             has_any_match.unsqueeze(1),
             extracted_sequences,
-            torch.full_like(extracted_sequences, 0),  # TODO:(patchy): Use -1 instead of 0.
+            torch.full_like(extracted_sequences, 0),
         )
 
         return results
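The `torch.where` call above is a per-row fallback: `has_any_match.unsqueeze(1)` broadcasts a per-request boolean over the extracted n-gram continuations, and rows with no match are filled with 0 (the dropped TODO had proposed `-1` as the sentinel instead). A toy reproduction with made-up shapes:

```python
import torch

# Made-up example: 2 requests, continuation length 3.
extracted_sequences = torch.tensor([[11, 12, 13], [7, 8, 9]])
has_any_match = torch.tensor([True, False])

results = torch.where(
    has_any_match.unsqueeze(1),                # [2] -> [2, 1], broadcast over dim 1
    extracted_sequences,                       # kept where a match was found
    torch.full_like(extracted_sequences, 0),   # fill value for no-match rows
)
print(results)  # tensor([[11, 12, 13], [ 0,  0,  0]])
```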
@@ -248,7 +248,7 @@ def forward(
         mask = combined_mask.unsqueeze(1).expand(-1, self.k)
         draft_tokens = torch.where(mask, results, draft_tokens)
 
-        is_empty_draft_tokens = (draft_tokens == 0).all(dim=1)  # TODO:(patchy): Use -1 instead of 0.
+        is_empty_draft_tokens = (draft_tokens == 0).all(dim=1)
 
         return draft_tokens, is_empty_draft_tokens
 
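Same sentinel convention one level up: `combined_mask` is broadcast across the `k` draft slots, and a request counts as having no draft only when every slot still holds the fill value, which is why the removed TODO wanted `-1` rather than a real token id. A small self-contained check; `k` and the tensors are invented:

```python
import torch

k = 3
combined_mask = torch.tensor([True, False])      # per-request eligibility
results = torch.tensor([[5, 6, 7], [0, 0, 0]])
draft_tokens = torch.zeros(2, k, dtype=torch.int64)

mask = combined_mask.unsqueeze(1).expand(-1, k)  # [2] -> [2, 3]
draft_tokens = torch.where(mask, results, draft_tokens)

# A request is "empty" only if every slot still holds the fill value 0,
# which collides with token id 0 being a legitimate vocabulary entry.
is_empty_draft_tokens = (draft_tokens == 0).all(dim=1)
print(draft_tokens)           # tensor([[5, 6, 7], [0, 0, 0]])
print(is_empty_draft_tokens)  # tensor([False,  True])
```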
@@ -296,6 +296,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
         self.kernel.to(device)
         self.kernel.eval()
         max_batch_size = vllm_config.scheduler_config.max_num_seqs
+
+        # TODO(patchy): Remove this buffer, use
+        # token_ids_gpu_tensor in gpu_model_runner.py instead.
         self.backup_next_token_ids = CpuGpuBuffer(
             max_batch_size,
             dtype=torch.int32,
@@ -309,7 +312,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
     def _dummy_run(self):
         token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
             batch_size=self.max_num_seqs,
-            max_seq_len=min(self.max_model_len, 1024),
+            max_seq_len=self.max_model_len,
             vocab_size=self.vocab_size,
             pattern_len=self.k,
             repetition_rate=0.5,
@@ -354,33 +357,18 @@ def _generate_dummy_data(
             valid_mask: [batch_size] bool tensor
         """
         # Generate random token IDs
-        token_ids = torch.randint(
-            0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device=device
+        token_ids = torch.zeros(
+            batch_size,
+            max_seq_len,
+            dtype=torch.int32,
+            device=device,
         )
 
         # Generate random sequence lengths
-        min_len = max(pattern_len * 2 + 3, max_seq_len // 2)
         num_tokens = torch.randint(
-            min_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
+            pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
         )
 
-        # Inject n-gram repetitions using the tail pattern of each sequence
-        for i in range(batch_size):
-            seq_len = num_tokens[i].item()
-            if seq_len > pattern_len * 2:
-                # Pattern is the last pattern_len tokens of the valid sequence
-                src_pos = seq_len - pattern_len
-                num_reps = int(seq_len * repetition_rate / pattern_len)
-                for _ in range(num_reps):
-                    # Place the copied tail pattern somewhere before the tail
-                    tgt_pos = torch.randint(0, seq_len - pattern_len, (1,)).item()
-                    if tgt_pos == src_pos:
-                        continue
-
-                    token_ids[i, tgt_pos : tgt_pos + pattern_len] = token_ids[
-                        i, src_pos : src_pos + pattern_len
-                    ].clone()
-
         # All sequences have sampled tokens and are valid
         sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
         valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
@@ -401,10 +389,7 @@ def propose(
 
         with set_forward_context(None, self.vllm_config):
             combined_mask = (
-                sampled_flags
-                & valid_mask
-                & (num_tokens_no_spec < self.max_model_len)
-                & (num_tokens_no_spec >= self.min_n)
+                sampled_flags & valid_mask & (num_tokens_no_spec >= self.min_n)
             )
 
             draft_tokens, is_empty_draft_tokens = self.kernel(
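The simplified expression above keeps three eligibility conditions: a token was sampled, the request is valid, and there is at least `min_n` tokens of history; the `max_model_len` guard is dropped in this hunk. A toy illustration of the elementwise combination, with `min_n` and the inputs made up:

```python
import torch

min_n = 2
sampled_flags = torch.tensor([True, True, False])
valid_mask = torch.tensor([True, True, True])
num_tokens_no_spec = torch.tensor([5, 1, 9])

# Elementwise AND over boolean tensors: a request proposes drafts only
# if all three conditions hold for it.
combined_mask = sampled_flags & valid_mask & (num_tokens_no_spec >= min_n)
print(combined_mask)  # tensor([ True, False, False])
```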
@@ -415,36 +400,6 @@ def propose(
             )
 
         return draft_tokens, is_empty_draft_tokens
 
-    def prepare_next_token_ids_cpu(
-        self,
-        sampled_token_ids: list[np.ndarray],
-        requests: dict[str, CachedRequestState],
-        gpu_input_batch: InputBatch,
-        num_scheduled_tokens: dict[str, int],
-    ) -> torch.Tensor:
-        """
-        This function is used to prepare the inputs for speculative decoding.
-        It calculates the next token ids for each request based on the sampled
-        token ids from the CPU. If a request has no sampled token ids (e.g.,
-        during the initial decoding steps), it falls back to using the request
-        state to get the next token id.
-        """
-        req_ids = gpu_input_batch.req_ids
-        next_token_ids: list[int] = []
-        for i, token_ids in enumerate(sampled_token_ids):
-            if token_ids.shape[0] > 0:
-                # Common case.
-                next_token_id = token_ids[-1]
-            else:
-                # Partial prefill (rare case).
-                # Get the next token id from the request state.
-                req_id = req_ids[i]
-                req_state = requests[req_id]
-                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
-                next_token_id = req_state.get_token_id(seq_len)
-            next_token_ids.append(next_token_id)
-        return torch.tensor(next_token_ids, dtype=torch.int32, device=self.device)
-
     def prepare_next_token_ids_padded(
         self,
         common_attn_metadata: CommonAttentionMetadata,
@@ -463,8 +418,6 @@ def prepare_next_token_ids_padded(
         This function must use device functions to operate on the inputs, and
         should not introduce any blocking CPU-GPU synchronization.
         """
-        # TODO(Ben): Combine this into a custom fused kernel
-        # Precompute get_token_id for when there is no valid next token
         num_reqs = gpu_input_batch.num_reqs
         # Batch convert seq_lens to avoid multiple .item() calls
         seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
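The docstring's no-blocking-sync requirement is also why `seq_lens` is converted in one batched `tolist()` on the CPU copy rather than element by element; each `.item()` on a GPU tensor forces a separate device synchronization. A rough illustration of the difference, assuming a CUDA device is available (shapes are made up):

```python
import torch

seq_lens_gpu = torch.randint(1, 128, (8,), device="cuda")

# Many small syncs: one device-to-host transfer per element.
lens_slow = [seq_lens_gpu[i].item() for i in range(seq_lens_gpu.numel())]

# One transfer, then pure-CPU work, mirroring seq_lens_cpu[:num_reqs].tolist().
lens_fast = seq_lens_gpu.cpu().tolist()

assert lens_slow == lens_fast
```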