 This version uses a fully vectorized approach with unfold and argmax for
 finding the first match across all sequences in parallel.
 """
-
-import os
-
-import numpy as np
 import torch
 from torch import nn
+import numpy as np
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, VllmConfig
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
 )
-from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
-if int(os.environ.get("NVPROF", "0")) == 1:
-    pass
-else:
-    pass
-
-import logging
-
 from vllm.config import set_current_vllm_config
 from vllm.forward_context import set_forward_context
-
-logger = logging.getLogger(__name__)
-
+from vllm.v1.utils import CpuGpuBuffer
 
 @support_torch_compile(
     dynamic_arg_dims={
@@ -107,12 +94,10 @@ def _find_first_and_extract_all_n_parallel(
         max_seq_len = data.shape[1]
         num_patterns = max_pattern_len - min_pattern_len + 1
 
-        # Create sliding windows once
         all_windows = data.unfold(1, max_pattern_len, 1)  # [B, num_windows, max_n]
         num_windows = all_windows.shape[1]
         window_starts = torch.arange(num_windows, device=device)
 
-        # Store the first match position for each pattern length
         all_first_matches = torch.full(
             (batch_size, num_patterns), -1, dtype=torch.long, device=device
         )
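
For reviewers, a minimal standalone sketch (not from this patch; tensor values and names are made up) of the unfold-plus-argmax matching idea the module docstring describes:

```python
import torch

# Toy batch: for each row, find the first sliding window that equals that
# row's query pattern. The real kernel builds the pattern from each
# sequence's tail and adds variable-length and no-match handling on top.
data = torch.tensor([[7, 1, 2, 9, 1, 2, 5],
                     [3, 3, 4, 4, 3, 4, 8]])
pattern = torch.tensor([[1, 2],
                        [3, 4]])                   # one length-2 pattern per row
windows = data.unfold(1, 2, 1)                     # [B, W, 2] sliding windows
hits = (windows == pattern.unsqueeze(1)).all(-1)   # [B, W] bool
# argmax over an int tensor returns the index of the FIRST maximum, so it
# yields the first matching window; rows with no hit need a separate mask.
first_pos = hits.int().argmax(dim=1)               # tensor([1, 1])
has_match = hits.any(dim=1)                        # tensor([True, True])
```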
@@ -148,7 +133,6 @@ def _find_first_and_extract_all_n_parallel(
         # from back to front, prioritizing longer patterns
         best_pattern_idx = (all_first_matches >= 0).int().flip(dims=[1]).argmax(dim=1)
         best_pattern_idx = num_patterns - 1 - best_pattern_idx  # Flip back
-        best_pattern_len = min_pattern_len + best_pattern_idx
 
         # Extract corresponding results
         batch_idx = torch.arange(batch_size, device=device)
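
A hedged aside on the flip-then-argmax selection just above (toy values, not from the patch): `argmax` returns the first maximal element, so flipping the pattern axis makes it pick the last, i.e. longest, pattern length that found a match.

```python
import torch

# all_first_matches[b, p] = first match position for pattern length
# min_pattern_len + p, or -1 if that length found no match (toy values).
all_first_matches = torch.tensor([[ 3, -1,  7],
                                  [-1, -1, -1]])
num_patterns = all_first_matches.shape[1]

has_match = all_first_matches >= 0                      # [B, P] bool
best_flipped = has_match.int().flip(dims=[1]).argmax(dim=1)
best_pattern_idx = num_patterns - 1 - best_flipped      # undo the flip
# Row 0 -> 2 (the longest matching length); row 1 matched nothing, so its
# index is only meaningful where has_match.any(dim=1) is True.
```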
@@ -163,9 +147,8 @@ def _find_first_and_extract_all_n_parallel(
         # the result starts after the full window
         result_starts = torch.where(
             has_any_match,
-            best_match_pos
-            + max_pattern_len,  # Use max_pattern_len, not best_pattern_len
-            torch.zeros_like(best_match_pos),  # Use 0 for no match
+            best_match_pos + max_pattern_len,
+            torch.zeros_like(best_match_pos),
         )
 
         # Create gather indices
@@ -226,13 +209,12 @@ def forward(
         device = token_ids_gpu.device
 
         # Initialize output tensor - torch.compile will optimize this allocation
-        # NOTE: Do NOT pre-allocate this as a buffer - it would break torch.compile
+        # NOTE(patchy): Do NOT pre-allocate this as a buffer;
+        # it would break torch.compile.
         draft_tokens = torch.zeros(
             (batch_size, self.k), dtype=torch.int32, device=device
         )
 
-        # Use the async find and extract method with max_n pattern length
-        # This will find the first match and extract k tokens
         results = self._find_first_and_extract_all_n_parallel(
             token_ids_gpu,
             num_tokens_no_spec,
@@ -250,28 +232,31 @@ def load_model(self, *args, **kwargs):
250232 """No model to load for N-gram proposer."""
251233 pass
252234
253-
254235class NgramProposerGPU :
255236 def __init__ (self , vllm_config : VllmConfig , device : torch .device , runner = None ):
256237 assert vllm_config .speculative_config is not None
257238 assert vllm_config .speculative_config .prompt_lookup_min is not None
258239 assert vllm_config .speculative_config .prompt_lookup_max is not None
259240
260- # Create optimized compilation config for ngram kernel
261241 compilation_config = CompilationConfig (
262- level = 3 ,
242+ level = 3 ,
263243 custom_ops = ["none" ],
264244 splitting_ops = [],
265245 compile_sizes = [],
266246 inductor_compile_config = {
267- "enable_auto_functionalized_v2" : False ,
247+ "enable_auto_functionalized_v2" : False ,
248+ "max_autotune" : True ,
249+ "aggressive_fusion" : True ,
250+ "triton.autotune_pointwise" : True ,
251+ "coordinate_descent_tuning" : True ,
252+ "use_mixed_mm" : False ,
268253 },
269254 use_cudagraph = False ,
270- cudagraph_mode = CUDAGraphMode .NONE ,
271- mode = CompilationMode .VLLM_COMPILE ,
272255 )
273256
274- self .vllm_config = VllmConfig (compilation_config = compilation_config )
257+ self .vllm_config = VllmConfig (
258+ compilation_config = compilation_config
259+ )
275260
276261 self .min_n = vllm_config .speculative_config .prompt_lookup_min
277262 self .max_n = vllm_config .speculative_config .prompt_lookup_max
@@ -282,9 +267,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
         self.device = device
 
         with set_current_vllm_config(self.vllm_config, check_compile=False):
-            self.kernel = NgramGPUKernel(
-                vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device
-            )
+            self.kernel = NgramGPUKernel(vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device)
         self.device = device
         self.kernel.to(device)
         self.kernel.eval()
@@ -300,26 +283,19 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
         self._dummy_run()
 
     def _dummy_run(self):
-        # with set_current_vllm_config(self.vllm_config, check_compile=False):
-        # Get warmup iterations from config or use default
-        token_ids, num_tokens, sampled_flags, valid_mask = (
-            self._generate_dummy_data(
+        with set_current_vllm_config(self.vllm_config, check_compile=False):
+            token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
                 batch_size=self.max_num_seqs,
-                max_seq_len=min(
-                    self.max_model_len, 1024
-                ),  # Use reasonable seq len for warmup
+                max_seq_len=min(self.max_model_len, 1024),
                 vocab_size=self.vocab_size,
                 pattern_len=self.k,
                 repetition_rate=0.5,
-                device=self.device,
+                device=self.device
             )
-        )
 
-        for _ in range(3):
-            with set_forward_context(None, self.vllm_config):
-                _ = self.kernel(
-                    num_tokens, token_ids, sampled_flags, valid_mask
-                )
+            for _ in range(3):
+                with set_forward_context(None, self.vllm_config):
+                    output = self.kernel(num_tokens, token_ids, sampled_flags, valid_mask)
 
     def _generate_dummy_data(
         self,
@@ -349,13 +325,15 @@ def _generate_dummy_data(
349325 """
350326 # Generate random token IDs
351327 token_ids = torch .randint (
352- 0 , vocab_size , (batch_size , max_seq_len ), dtype = torch .int32 , device = device
328+ 0 , vocab_size , (batch_size , max_seq_len ),
329+ dtype = torch .int32 , device = device
353330 )
354331
355332 # Generate random sequence lengths
356333 min_len = max (pattern_len * 2 + 3 , max_seq_len // 2 )
357334 num_tokens = torch .randint (
358- min_len , max_seq_len , (batch_size ,), dtype = torch .int32 , device = device
335+ min_len , max_seq_len , (batch_size ,),
336+ dtype = torch .int32 , device = device
359337 )
360338
361339 # Inject n-gram repetitions using the tail pattern of each sequence
@@ -371,9 +349,8 @@ def _generate_dummy_data(
             if tgt_pos == src_pos:
                 continue
 
-            token_ids[i, tgt_pos : tgt_pos + pattern_len] = token_ids[
-                i, src_pos : src_pos + pattern_len
-            ].clone()
+            token_ids[i, tgt_pos:tgt_pos + pattern_len] = \
+                token_ids[i, src_pos:src_pos + pattern_len].clone()
 
         # All sequences have sampled tokens and are valid
         sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
@@ -388,11 +365,9 @@ def propose(
         sampled_flags: torch.Tensor,  # [batch_size] bool on GPU
         valid_mask: torch.Tensor,  # [batch_size] bool on GPU
     ) -> torch.Tensor:
-        # with set_current_vllm_config(self.vllm_config, check_compile=False):
-        with set_forward_context(None, self.vllm_config):
-            return self.kernel(
-                num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask
-            )
+        with set_current_vllm_config(self.vllm_config, check_compile=False):
+            with set_forward_context(None, self.vllm_config):
+                return self.kernel(num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask)
 
     def prepare_next_token_ids_cpu(
         self,
@@ -443,12 +418,9 @@ def prepare_next_token_ids_padded(
         should not introduce any blocking CPU-GPU synchronization.
         """
         # TODO(Ben): Combine this into a custom fused kernel
-
         # Precompute get_token_id for when there is no valid next token
         num_reqs = gpu_input_batch.num_reqs
         # Batch convert seq_lens to avoid multiple .item() calls
-        # This performs a single synchronization for all lengths
-        # instead of one per request
         seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
 
         # Now use the pre-converted list to avoid .item() calls in the loop
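
As background for the `.tolist()` line kept above, a rough sketch (illustrative tensor, not from the patch) of why one batched conversion is preferred over per-request `.item()` calls:

```python
import torch

seq_lens_cpu = torch.tensor([17, 5, 42, 9], dtype=torch.int32)  # toy values
num_reqs = 3

# Per-request .item() calls: one tensor-to-Python round trip per request.
lens_slow = [int(seq_lens_cpu[i].item()) for i in range(num_reqs)]

# Single batched conversion: one .tolist() for the whole slice, which is
# the pattern the method relies on before entering its Python loop.
lens_fast = seq_lens_cpu[:num_reqs].tolist()

assert lens_slow == lens_fast
```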
@@ -500,4 +472,4 @@ def prepare_next_token_ids_padded(
 
     def load_model(self, *args, **kwargs):
         with set_current_vllm_config(self.vllm_config, check_compile=False):
-            self.kernel.load_model(*args, **kwargs)
+            self.kernel.load_model(*args, **kwargs)