Commit b5243ec

Author: PatchouliTaisa
refactor ngram gpu
Signed-off-by: PatchouliTaisa <[email protected]>
1 parent 1fbf296 commit b5243ec

2 files changed: +30 -48 lines

vllm/v1/spec_decode/ngram_proposer_gpu.py

Lines changed: 14 additions & 35 deletions
@@ -7,7 +7,6 @@
 finding the first match across all sequences in parallel.
 """

-import numpy as np
 import torch
 from torch import nn

@@ -17,12 +16,7 @@
     VllmConfig,
 )
 from vllm.forward_context import set_forward_context
-from vllm.utils.platform_utils import is_pin_memory_available
-from vllm.v1.attention.backends.utils import (
-    CommonAttentionMetadata,
-)
-from vllm.v1.utils import CpuGpuBuffer
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_input_batch import InputBatch


 @support_torch_compile(
@@ -295,17 +289,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
         self.device = device
         self.kernel.to(device)
         self.kernel.eval()
-        max_batch_size = vllm_config.scheduler_config.max_num_seqs
-
-        # TODO(patchy): Remove this buffer, use
-        # token_ids_gpu_tensor in gpu_model_runner.py instead.
-        self.backup_next_token_ids = CpuGpuBuffer(
-            max_batch_size,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available(),
-            device=device,
-            with_numpy=True,
-        )

         self._dummy_run()

@@ -400,14 +383,14 @@ def propose(

         return draft_tokens, is_empty_draft_tokens

-    def prepare_next_token_ids_padded(
+    def update_token_ids_ngram(
         self,
-        common_attn_metadata: CommonAttentionMetadata,
         sampled_token_ids: torch.Tensor,
-        requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
         discard_request_indices: torch.Tensor,
         num_discarded_requests: int,
+        token_ids_gpu: torch.Tensor,
+        num_tokens_no_spec: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         This function is used to prepare the inputs for speculative decoding.
@@ -419,17 +402,14 @@ def prepare_next_token_ids_padded(
         should not introduce any blocking CPU-GPU synchronization.
         """
         num_reqs = gpu_input_batch.num_reqs
-        # Batch convert seq_lens to avoid multiple .item() calls
-        seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
-
-        # Now use the pre-converted list to avoid .item() calls in the loop
-        self.backup_next_token_ids.np[:num_reqs] = np.array(
-            [
-                requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
-                for i in range(num_reqs)
-            ]
-        )
-        self.backup_next_token_ids.copy_to_gpu(num_reqs)
+
+        # Extract backup_next_token_ids from token_ids_gpu using vectorized gather
+        # For each request i, get token_ids_gpu[i, num_tokens_no_spec[i] - 1]
+        # This is the last valid token before speculative tokens
+        backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
+        backup_next_token_ids = torch.gather(
+            token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1)
+        ).squeeze(1)

         # Mask out the sampled tokens indices that should not be sampled.
         discard_sampled_tokens_req_indices = discard_request_indices[
@@ -459,12 +439,11 @@ def prepare_next_token_ids_padded(
             valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
         ).squeeze(1)

-        # Use last token if valid, pre-computed backup if not
-        batch_size = valid_sampled_token_ids_gpu.shape[0]
+        # Use last token if valid, vectorized backup from token_ids_gpu if not
         next_token_ids = torch.where(
             last_valid_indices != -1,
             selected_tokens,
-            self.backup_next_token_ids.gpu[:batch_size],
+            backup_next_token_ids,
         )

         return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu
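For reference, here is a minimal standalone sketch of the gather-plus-where pattern the new update_token_ids_ngram path relies on: the backup token for each request is read directly from the on-GPU token table and substituted only where no valid token was sampled, replacing the old per-request Python loop and CpuGpuBuffer copy. The function name, tensor shapes, and the small example below are illustrative assumptions, not vLLM's actual API.

import torch

def select_next_token_ids(
    token_ids_gpu: torch.Tensor,        # [num_reqs, max_len] per-request token table
    num_tokens_no_spec: torch.Tensor,   # [num_reqs] token count excluding spec tokens
    selected_tokens: torch.Tensor,      # [num_reqs] last sampled token per request
    last_valid_indices: torch.Tensor,   # [num_reqs], -1 where no valid sample exists
) -> torch.Tensor:
    num_reqs = token_ids_gpu.shape[0]
    # Index of the last valid token before any speculative tokens.
    backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
    # token_ids_gpu[i, backup_indices[i]] for every request, with no .item() calls.
    backup_next_token_ids = torch.gather(
        token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1)
    ).squeeze(1)
    # Keep the sampled token where it is valid, otherwise fall back to the backup.
    return torch.where(last_valid_indices != -1, selected_tokens, backup_next_token_ids)

# Tiny example: request 1 has no valid sample, so it falls back to its last
# non-speculative token (24).
token_ids = torch.tensor([[11, 12, 13, 0], [21, 22, 23, 24]])
lens = torch.tensor([3, 4])
sampled = torch.tensor([99, 98])
valid = torch.tensor([0, -1])
print(select_next_token_ids(token_ids, lens, sampled, valid))  # tensor([99, 24])

Because everything stays on the GPU, this path avoids the blocking CPU-GPU synchronization the docstring warns about.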

vllm/v1/worker/gpu_model_runner.py

Lines changed: 16 additions & 13 deletions
@@ -1026,7 +1026,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:

         # Incrementally update ngram_gpu tensors after batch is stable
         if is_ngram_gpu:
-            self._update_ngram_gpu_tensors_incremental(ngram_gpu_new_reqs)
+            with record_function_or_nullcontext("update_ngram_gpu_tensors_incremental"):
+                self._update_ngram_gpu_tensors_incremental(ngram_gpu_new_reqs)

     def _update_ngram_gpu_tensors_incremental(
         self,
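The profiling wrapper added in this hunk follows a common pattern: enter torch.profiler.record_function when profiling is wanted, and a no-op context otherwise, so the region shows up in traces without adding overhead on the normal path. A rough sketch of that pattern is below; the plain `enabled` flag is an assumption for illustration only, since vLLM's actual helper gates this on its own configuration.

import contextlib
import torch

def record_function_or_nullcontext_sketch(name: str, enabled: bool):
    """Return a profiler annotation context when enabled, else a no-op context."""
    if enabled:
        # Labels the wrapped region with `name` in torch profiler traces.
        return torch.profiler.record_function(name)
    return contextlib.nullcontext()

# Hypothetical usage mirroring the call site above:
# with record_function_or_nullcontext_sketch("update_ngram_gpu_tensors_incremental", True):
#     update_ngram_gpu_tensors_incremental(new_reqs)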
@@ -3264,13 +3265,13 @@ def propose_draft_token_ids(sampled_token_ids):
             elif self.valid_sampled_token_count_event is not None:
                 assert spec_decode_common_attn_metadata is not None
                 next_token_ids, valid_sampled_tokens_count, _ = (
-                    self.drafter.prepare_next_token_ids_padded(
-                        spec_decode_common_attn_metadata,
+                    self.drafter.update_token_ids_ngram(
                         sampled_token_ids,
-                        self.requests,
                         self.input_batch,
                         self.discard_request_indices.gpu,
                         self.num_discarded_requests,
+                        self.token_ids_gpu_tensor,
+                        self.num_tokens_no_spec_gpu,
                     )
                 )
                 self._copy_valid_sampled_token_count(
@@ -3426,15 +3427,17 @@ def propose_draft_token_ids(
         assert isinstance(sampled_token_ids, torch.Tensor), (
             "sampled_token_ids should be a torch.Tensor for ngram_gpu"
         )
-        next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu = (
-            self.drafter.prepare_next_token_ids_padded(
-                common_attn_metadata,
-                sampled_token_ids,
-                self.requests,
-                self.input_batch,
-                self.discard_request_indices.gpu,
-                self.num_discarded_requests,
-            )
+        (
+            next_token_ids,
+            valid_sampled_tokens_count,
+            valid_sampled_token_ids_gpu,
+        ) = self.drafter.update_token_ids_ngram(
+            sampled_token_ids,
+            self.input_batch,
+            self.discard_request_indices.gpu,
+            self.num_discarded_requests,
+            self.token_ids_gpu_tensor,
+            self.num_tokens_no_spec_gpu,
         )
         self._copy_valid_sampled_token_count(
             next_token_ids, valid_sampled_tokens_count
