
Commit 293e3ae

Author: PatchouliTaisa

fix return values in ngram gpu

Signed-off-by: PatchouliTaisa <[email protected]>

Parent: d183dcb

File tree: 1 file changed (+34 −62 lines)


vllm/v1/spec_decode/ngram_proposer_gpu.py

Lines changed: 34 additions & 62 deletions
@@ -6,34 +6,21 @@
 This version uses a fully vectorized approach with unfold and argmax for
 finding the first match across all sequences in parallel.
 """
-
-import os
-
-import numpy as np
 import torch
 from torch import nn
+import numpy as np
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, VllmConfig
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
 )
-from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
-if int(os.environ.get("NVPROF", "0")) == 1:
-    pass
-else:
-    pass
-
-import logging
-
 from vllm.config import set_current_vllm_config
 from vllm.forward_context import set_forward_context
-
-logger = logging.getLogger(__name__)
-
+from vllm.v1.utils import CpuGpuBuffer
 
 @support_torch_compile(
     dynamic_arg_dims={
@@ -107,12 +94,10 @@ def _find_first_and_extract_all_n_parallel(
         max_seq_len = data.shape[1]
         num_patterns = max_pattern_len - min_pattern_len + 1
 
-        # Create sliding windows once
         all_windows = data.unfold(1, max_pattern_len, 1)  # [B, num_windows, max_n]
         num_windows = all_windows.shape[1]
         window_starts = torch.arange(num_windows, device=device)
 
-        # Store the first match position for each pattern length
         all_first_matches = torch.full(
             (batch_size, num_patterns), -1, dtype=torch.long, device=device
         )
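For readers unfamiliar with the unfold idiom the hunk above relies on, here is a minimal standalone sketch of matching each sequence's tail n-gram against every sliding window. The names are illustrative, and it omits the per-sequence length masking and multi-pattern handling that the real kernel performs:

import torch

def first_tail_match(data: torch.Tensor, pattern_len: int) -> torch.Tensor:
    """For each row, return the start index of the earliest window equal to
    the row's last `pattern_len` tokens, or -1 if there is none."""
    windows = data.unfold(1, pattern_len, 1)        # [B, num_windows, n]
    pattern = data[:, -pattern_len:].unsqueeze(1)   # [B, 1, n]
    matches = (windows == pattern).all(dim=-1)      # [B, num_windows]
    matches[:, -1] = False     # drop the trivial match of the tail against itself
    first = matches.int().argmax(dim=1)             # index of the first True per row
    has_match = matches.any(dim=1)
    return torch.where(has_match, first, torch.full_like(first, -1))

# Toy example: the tail 3-gram [1, 2, 3] first appears at index 2.
seq = torch.tensor([[5, 9, 1, 2, 3, 7, 8, 1, 2, 3]])
print(first_tail_match(seq, 3))   # tensor([2])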
@@ -148,7 +133,6 @@ def _find_first_and_extract_all_n_parallel(
         # from back to front, prioritizing longer patterns
         best_pattern_idx = (all_first_matches >= 0).int().flip(dims=[1]).argmax(dim=1)
         best_pattern_idx = num_patterns - 1 - best_pattern_idx  # Flip back
-        best_pattern_len = min_pattern_len + best_pattern_idx
 
         # Extract corresponding results
         batch_idx = torch.arange(batch_size, device=device)
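The flip-then-argmax selection above can be opaque at first glance; a small self-contained example with made-up values shows how it picks the longest pattern length that found a match:

import torch

# Per-sequence first-match positions for each pattern length (columns are
# ordered min_n .. max_n); -1 means that pattern length found no match.
all_first_matches = torch.tensor([[4, 7, -1]])   # made-up values, 3 patterns
found = (all_first_matches >= 0).int()           # tensor([[1, 1, 0]])
# argmax on the flipped row returns the first hit counted from the right,
# i.e. the longest pattern that matched; un-flipping the index gives column 1.
num_patterns = found.shape[1]
best_pattern_idx = num_patterns - 1 - found.flip(dims=[1]).argmax(dim=1)
print(best_pattern_idx)                          # tensor([1])
# Rows with no match at all still get an index here; the kernel masks them
# out separately with a has-any-match flag.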
@@ -163,9 +147,8 @@ def _find_first_and_extract_all_n_parallel(
         # the result starts after the full window
         result_starts = torch.where(
             has_any_match,
-            best_match_pos
-            + max_pattern_len,  # Use max_pattern_len, not best_pattern_len
-            torch.zeros_like(best_match_pos),  # Use 0 for no match
+            best_match_pos + max_pattern_len,
+            torch.zeros_like(best_match_pos),
         )
 
         # Create gather indices
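After result_starts is computed, the kernel gathers k continuation tokens per row without branching. The gather itself lies outside this hunk, so the following is only an assumed, self-contained sketch of the idea with illustrative values:

import torch

# Take k tokens starting at result_starts for every row; rows with no match
# read from index 0 and are discarded later by a validity mask, as above.
data = torch.tensor([[5, 9, 1, 2, 3, 7, 8, 1, 2, 3],
                     [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]])
result_starts = torch.tensor([5, 0])       # row 0 matched, row 1 did not
k = 3
gather_idx = result_starts.unsqueeze(1) + torch.arange(k)   # [B, k]
gather_idx = gather_idx.clamp(max=data.shape[1] - 1)        # stay in bounds
drafts = torch.gather(data, 1, gather_idx)
print(drafts)   # tensor([[7, 8, 1], [4, 4, 4]])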
@@ -226,13 +209,12 @@ def forward(
         device = token_ids_gpu.device
 
         # Initialize output tensor - torch.compile will optimize this allocation
-        # NOTE: Do NOT pre-allocate this as a buffer - it would break torch.compile
+        # NOTE(patchy): Do NOT pre-allocate this as a buffer
+        # it would break torch.compile
         draft_tokens = torch.zeros(
             (batch_size, self.k), dtype=torch.int32, device=device
         )
 
-        # Use the async find and extract method with max_n pattern length
-        # This will find the first match and extract k tokens
         results = self._find_first_and_extract_all_n_parallel(
             token_ids_gpu,
             num_tokens_no_spec,
@@ -250,28 +232,31 @@ def load_model(self, *args, **kwargs):
         """No model to load for N-gram proposer."""
         pass
 
-
 class NgramProposerGPU:
     def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
         assert vllm_config.speculative_config is not None
         assert vllm_config.speculative_config.prompt_lookup_min is not None
         assert vllm_config.speculative_config.prompt_lookup_max is not None
 
-        # Create optimized compilation config for ngram kernel
         compilation_config = CompilationConfig(
-            level=3,
+            level=3,
             custom_ops=["none"],
             splitting_ops=[],
             compile_sizes=[],
             inductor_compile_config={
-                "enable_auto_functionalized_v2": False,
+                "enable_auto_functionalized_v2": False,
+                "max_autotune": True,
+                "aggressive_fusion": True,
+                "triton.autotune_pointwise": True,
+                "coordinate_descent_tuning": True,
+                "use_mixed_mm": False,
             },
             use_cudagraph=False,
-            cudagraph_mode=CUDAGraphMode.NONE,
-            mode=CompilationMode.VLLM_COMPILE,
         )
 
-        self.vllm_config = VllmConfig(compilation_config=compilation_config)
+        self.vllm_config = VllmConfig(
+            compilation_config=compilation_config
+        )
 
         self.min_n = vllm_config.speculative_config.prompt_lookup_min
         self.max_n = vllm_config.speculative_config.prompt_lookup_max
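The new inductor_compile_config entries are standard torch._inductor knobs. As a point of reference, roughly the same options can be passed when invoking torch.compile directly; this is only a hedged sketch with a placeholder function, and option availability varies with the installed torch version:

import torch

def toy_fn(x: torch.Tensor) -> torch.Tensor:   # placeholder workload
    return (x * 2 + 1).relu()

compiled = torch.compile(
    toy_fn,
    options={
        "max_autotune": True,                # same knobs as the config above;
        "coordinate_descent_tuning": True,   # whether each is honored depends
        "triton.autotune_pointwise": True,   # on the installed torch version
    },
)
print(compiled(torch.randn(4)))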
@@ -282,9 +267,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
         self.device = device
 
         with set_current_vllm_config(self.vllm_config, check_compile=False):
-            self.kernel = NgramGPUKernel(
-                vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device
-            )
+            self.kernel = NgramGPUKernel(vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device)
         self.device = device
         self.kernel.to(device)
         self.kernel.eval()
@@ -300,26 +283,19 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
         self._dummy_run()
 
     def _dummy_run(self):
-        # with set_current_vllm_config(self.vllm_config, check_compile=False):
-        # Get warmup iterations from config or use default
-        token_ids, num_tokens, sampled_flags, valid_mask = (
-            self._generate_dummy_data(
+        with set_current_vllm_config(self.vllm_config, check_compile=False):
+            token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
                 batch_size=self.max_num_seqs,
-                max_seq_len=min(
-                    self.max_model_len, 1024
-                ),  # Use reasonable seq len for warmup
+                max_seq_len=min(self.max_model_len, 1024),
                 vocab_size=self.vocab_size,
                 pattern_len=self.k,
                 repetition_rate=0.5,
-                device=self.device,
+                device=self.device
             )
-        )
 
-        for _ in range(3):
-            with set_forward_context(None, self.vllm_config):
-                _ = self.kernel(
-                    num_tokens, token_ids, sampled_flags, valid_mask
-                )
+            for _ in range(3):
+                with set_forward_context(None, self.vllm_config):
+                    output = self.kernel(num_tokens, token_ids, sampled_flags, valid_mask)
 
     def _generate_dummy_data(
         self,
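The reshaped _dummy_run is the usual compile-warmup idiom: exercise the compiled module a few times on representative inputs so tracing and autotuning happen before the first real request. Stripped of the vLLM context managers, the generic pattern looks roughly like this; ToyKernel is a placeholder, not the real kernel:

import torch
from torch import nn

class ToyKernel(nn.Module):                  # placeholder, not the real kernel
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # trivially compare neighbouring tokens, just to give compile work to trace
        return (x[:, :-1] == x[:, 1:]).int().argmax(dim=1)

kernel = torch.compile(ToyKernel().eval())
dummy = torch.randint(0, 32_000, (8, 1024), dtype=torch.int32)
with torch.inference_mode():
    for _ in range(3):                       # a few warmup iterations, as above
        _ = kernel(dummy)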
@@ -349,13 +325,15 @@ def _generate_dummy_data(
         """
         # Generate random token IDs
         token_ids = torch.randint(
-            0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device=device
+            0, vocab_size, (batch_size, max_seq_len),
+            dtype=torch.int32, device=device
         )
 
         # Generate random sequence lengths
         min_len = max(pattern_len * 2 + 3, max_seq_len // 2)
         num_tokens = torch.randint(
-            min_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
+            min_len, max_seq_len, (batch_size,),
+            dtype=torch.int32, device=device
         )
 
         # Inject n-gram repetitions using the tail pattern of each sequence
@@ -371,9 +349,8 @@ def _generate_dummy_data(
             if tgt_pos == src_pos:
                 continue
 
-            token_ids[i, tgt_pos : tgt_pos + pattern_len] = token_ids[
-                i, src_pos : src_pos + pattern_len
-            ].clone()
+            token_ids[i, tgt_pos:tgt_pos + pattern_len] = \
+                token_ids[i, src_pos:src_pos + pattern_len].clone()
 
         # All sequences have sampled tokens and are valid
         sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
@@ -388,11 +365,9 @@ def propose(
         sampled_flags: torch.Tensor,  # [batch_size] bool on GPU
         valid_mask: torch.Tensor,  # [batch_size] bool on GPU
     ) -> torch.Tensor:
-        # with set_current_vllm_config(self.vllm_config, check_compile=False):
-        with set_forward_context(None, self.vllm_config):
-            return self.kernel(
-                num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask
-            )
+        with set_current_vllm_config(self.vllm_config, check_compile=False):
+            with set_forward_context(None, self.vllm_config):
+                return self.kernel(num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask)
 
     def prepare_next_token_ids_cpu(
         self,
@@ -443,12 +418,9 @@ def prepare_next_token_ids_padded(
         should not introduce any blocking CPU-GPU synchronization.
         """
         # TODO(Ben): Combine this into a custom fused kernel
-
         # Precompute get_token_id for when there is no valid next token
         num_reqs = gpu_input_batch.num_reqs
         # Batch convert seq_lens to avoid multiple .item() calls
-        # This performs a single synchronization for all lengths
-        # instead of one per request
         seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
 
         # Now use the pre-converted list to avoid .item() calls in the loop
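The deleted comment spelled out the rationale: one bulk conversion replaces a .item() call per request inside the loop. A self-contained toy contrast with made-up values:

import torch

# Per-element .item() versus a single bulk .tolist(); both yield the same
# Python ints, but the bulk form is one conversion instead of one per request.
seq_lens = torch.tensor([17, 5, 42, 9], dtype=torch.int32)
num_reqs = seq_lens.shape[0]
slow = [int(seq_lens[i].item()) for i in range(num_reqs)]   # one call per request
fast = seq_lens[:num_reqs].tolist()                         # single conversion
assert slow == fast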
@@ -500,4 +472,4 @@ def prepare_next_token_ids_padded(
 
     def load_model(self, *args, **kwargs):
         with set_current_vllm_config(self.vllm_config, check_compile=False):
-            self.kernel.load_model(*args, **kwargs)
+            self.kernel.load_model(*args, **kwargs)
