
Commit 1fbf296

Author: PatchouliTaisa (committed)

fix large batch performance.

Signed-off-by: PatchouliTaisa <[email protected]>

1 parent f6f871f, commit 1fbf296

File tree: 5 files changed, +179 -283 lines

vllm/v1/core/sched/scheduler.py

8 additions, 19 deletions

@@ -749,17 +749,8 @@ def _update_after_schedule(
         # 3. If some tokens (e.g. spec tokens) are rejected later, the number of
         # computed tokens will be adjusted in update_from_output.
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
         for req_id, num_scheduled_token in num_scheduled_tokens.items():
             request = self.requests[req_id]
-            # DEBUG LOG: Track num_computed_tokens update in scheduler
-            spec_tokens = spec_decode_tokens.get(req_id, [])
-            logger.info(f"[DEBUG-SCHED] _update_after_schedule: "
-                        f"req_id={req_id}, "
-                        f"num_computed_tokens_before={request.num_computed_tokens}, "
-                        f"num_scheduled_token={num_scheduled_token}, "
-                        f"spec_decode_tokens={spec_tokens}, "
-                        f"num_computed_tokens_after={request.num_computed_tokens + num_scheduled_token}")
             request.num_computed_tokens += num_scheduled_token

             # NOTE: _free_encoder_inputs relies on num_computed_tokens, which
@@ -1005,6 +996,7 @@ def update_from_output(
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
         kv_connector_output = model_runner_output.kv_connector_output
+        is_empty_draft_tokens = model_runner_output.is_empty_draft_tokens

         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: SpecDecodingStats | None = None
@@ -1047,23 +1039,16 @@ def update_from_output(
                 sampled_token_ids[req_index] if sampled_token_ids else []
             )

+            req_is_empty_draft_tokens = (
+                is_empty_draft_tokens[req_index] if is_empty_draft_tokens else False
+            )
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id)
             )
             if scheduled_spec_token_ids:
                 num_draft_tokens = len(scheduled_spec_token_ids)
                 num_accepted = len(generated_token_ids) - 1
                 num_rejected = num_draft_tokens - num_accepted
-                # DEBUG LOG: Track scheduler adjustment
-                logger.info(f"[DEBUG-SCHED] Adjusting in update_from_output: "
-                            f"req_id={req_id}, "
-                            f"scheduled_spec_token_ids={scheduled_spec_token_ids}, "
-                            f"num_draft_tokens={num_draft_tokens}, "
-                            f"generated_token_ids_len={len(generated_token_ids)}, "
-                            f"num_accepted={num_accepted}, "
-                            f"num_rejected={num_rejected}, "
-                            f"num_computed_tokens_before={request.num_computed_tokens}, "
-                            f"num_computed_tokens_after={request.num_computed_tokens - num_rejected if request.num_computed_tokens > 0 else request.num_computed_tokens}")
                 # num_computed_tokens represents the number of tokens
                 # processed in the current step, considering scheduled
                 # tokens and rejections. If some tokens are rejected,
@@ -1090,6 +1075,10 @@ def update_from_output(
             # Check for stop and update request status.
             # logger.info(f"In Scheduler::_update_request_with_output inside loop")
             # from fpdb import ForkedPdb; ForkedPdb().set_trace()
+
+            if req_is_empty_draft_tokens:
+                request.spec_token_ids = []
+
             if new_token_ids:
                 new_token_ids, stopped = self._update_request_with_output(
                     request, new_token_ids
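
Taken together, the scheduler side of this commit removes the per-step [DEBUG-SCHED] logging and starts consuming a per-request empty-draft flag from the model runner. A minimal sketch of that consumption path, using a simplified request object; the helper name _clear_stale_spec_tokens is illustrative and not part of the real Scheduler:

# Hedged sketch of how the new flag is read in update_from_output.
# `is_empty_draft_tokens` is the optional per-request list from ModelRunnerOutput;
# `request` stands in for the scheduler's cached request state.
def _clear_stale_spec_tokens(request, req_index, is_empty_draft_tokens):
    # The list is optional, so default to False when the runner did not report it.
    req_is_empty = (
        is_empty_draft_tokens[req_index] if is_empty_draft_tokens else False
    )
    if req_is_empty:
        # The GPU proposer produced no usable draft for this request,
        # so drop any speculative token ids instead of carrying them forward.
        request.spec_token_ids = []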

vllm/v1/outputs.py

3 additions, 0 deletions

@@ -181,6 +181,9 @@ class ModelRunnerOutput:
     # req_id -> num_nans_in_logits
     num_nans_in_logits: dict[str, int] | None = None

+    # [num_reqs]
+    is_empty_draft_tokens: list[bool] | None = None
+

 # ModelRunnerOutput wrapper for async scheduling.
 class AsyncModelRunnerOutput(ABC):
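
The new field mirrors the other optional per-request lists on ModelRunnerOutput: it is indexed by request position and may be None when no draft information was produced. A trimmed, self-contained sketch of that access pattern; the _MiniRunnerOutput dataclass below is a stand-in, not the real class:

from dataclasses import dataclass

@dataclass
class _MiniRunnerOutput:
    req_ids: list[str]
    # [num_reqs]; None when no draft information was produced this step
    is_empty_draft_tokens: list[bool] | None = None

out = _MiniRunnerOutput(req_ids=["req-0", "req-1"], is_empty_draft_tokens=[False, True])
for i, req_id in enumerate(out.req_ids):
    # Same guard the scheduler uses: treat a missing list as "not empty".
    flag = out.is_empty_draft_tokens[i] if out.is_empty_draft_tokens else False
    print(req_id, "empty draft:", flag)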

vllm/v1/spec_decode/ngram_proposer_gpu.py

14 additions, 61 deletions

@@ -28,7 +28,7 @@
 @support_torch_compile(
     dynamic_arg_dims={
         "num_tokens_no_spec": 0,
-        "token_ids_gpu": [0, 1],
+        "token_ids_gpu": 0,
         "combined_mask": 0,
     }
 )
@@ -196,7 +196,7 @@ def _find_first_and_extract_all_n_parallel(
         results = torch.where(
             has_any_match.unsqueeze(1),
             extracted_sequences,
-            torch.full_like(extracted_sequences, 0),  # TODO:(patchy): Use -1 instead of 0.
+            torch.full_like(extracted_sequences, 0),
         )

         return results
@@ -248,7 +248,7 @@ def forward(
         mask = combined_mask.unsqueeze(1).expand(-1, self.k)
         draft_tokens = torch.where(mask, results, draft_tokens)

-        is_empty_draft_tokens = (draft_tokens == 0).all(dim=1)  # TODO:(patchy): Use -1 instead of 0.
+        is_empty_draft_tokens = (draft_tokens == 0).all(dim=1)

         return draft_tokens, is_empty_draft_tokens

@@ -296,6 +296,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
         self.kernel.to(device)
         self.kernel.eval()
         max_batch_size = vllm_config.scheduler_config.max_num_seqs
+
+        # TODO(patchy): Remove this buffer, use
+        # token_ids_gpu_tensor in gpu_model_runner.py instead.
         self.backup_next_token_ids = CpuGpuBuffer(
             max_batch_size,
             dtype=torch.int32,
@@ -309,7 +312,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
     def _dummy_run(self):
         token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
             batch_size=self.max_num_seqs,
-            max_seq_len=min(self.max_model_len, 1024),
+            max_seq_len=self.max_model_len,
             vocab_size=self.vocab_size,
             pattern_len=self.k,
             repetition_rate=0.5,
@@ -354,33 +357,18 @@ def _generate_dummy_data(
             valid_mask: [batch_size] bool tensor
         """
         # Generate random token IDs
-        token_ids = torch.randint(
-            0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device=device
+        token_ids = torch.zeros(
+            batch_size,
+            max_seq_len,
+            dtype=torch.int32,
+            device=device,
         )

         # Generate random sequence lengths
-        min_len = max(pattern_len * 2 + 3, max_seq_len // 2)
         num_tokens = torch.randint(
-            min_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
+            pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
         )

-        # Inject n-gram repetitions using the tail pattern of each sequence
-        for i in range(batch_size):
-            seq_len = num_tokens[i].item()
-            if seq_len > pattern_len * 2:
-                # Pattern is the last pattern_len tokens of the valid sequence
-                src_pos = seq_len - pattern_len
-                num_reps = int(seq_len * repetition_rate / pattern_len)
-                for _ in range(num_reps):
-                    # Place the copied tail pattern somewhere before the tail
-                    tgt_pos = torch.randint(0, seq_len - pattern_len, (1,)).item()
-                    if tgt_pos == src_pos:
-                        continue
-
-                    token_ids[i, tgt_pos : tgt_pos + pattern_len] = token_ids[
-                        i, src_pos : src_pos + pattern_len
-                    ].clone()
-
         # All sequences have sampled tokens and are valid
         sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
         valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
@@ -401,10 +389,7 @@ def propose(

         with set_forward_context(None, self.vllm_config):
             combined_mask = (
-                sampled_flags
-                & valid_mask
-                & (num_tokens_no_spec < self.max_model_len)
-                & (num_tokens_no_spec >= self.min_n)
+                sampled_flags & valid_mask & (num_tokens_no_spec >= self.min_n)
             )

             draft_tokens, is_empty_draft_tokens = self.kernel(
@@ -415,36 +400,6 @@

         return draft_tokens, is_empty_draft_tokens

-    def prepare_next_token_ids_cpu(
-        self,
-        sampled_token_ids: list[np.ndarray],
-        requests: dict[str, CachedRequestState],
-        gpu_input_batch: InputBatch,
-        num_scheduled_tokens: dict[str, int],
-    ) -> torch.Tensor:
-        """
-        This function is used to prepare the inputs for speculative decoding.
-        It calculates the next token ids for each request based on the sampled
-        token ids from the CPU. If a request has no sampled token ids (e.g.,
-        during the initial decoding steps), it falls back to using the request
-        state to get the next token id.
-        """
-        req_ids = gpu_input_batch.req_ids
-        next_token_ids: list[int] = []
-        for i, token_ids in enumerate(sampled_token_ids):
-            if token_ids.shape[0] > 0:
-                # Common case.
-                next_token_id = token_ids[-1]
-            else:
-                # Partial prefill (rare case).
-                # Get the next token id from the request state.
-                req_id = req_ids[i]
-                req_state = requests[req_id]
-                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
-                next_token_id = req_state.get_token_id(seq_len)
-            next_token_ids.append(next_token_id)
-        return torch.tensor(next_token_ids, dtype=torch.int32, device=self.device)
-
     def prepare_next_token_ids_padded(
         self,
         common_attn_metadata: CommonAttentionMetadata,
@@ -463,8 +418,6 @@ def prepare_next_token_ids_padded(
         This function must use device functions to operate on the inputs, and
         should not introduce any blocking CPU-GPU synchronization.
         """
-        # TODO(Ben): Combine this into a custom fused kernel
-        # Precompute get_token_id for when there is no valid next token
         num_reqs = gpu_input_batch.num_reqs
         # Batch convert seq_lens to avoid multiple .item() calls
         seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
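
With the TODO comments dropped, 0 remains the fill value for rows where no n-gram match was found, and a row counts as "empty" when every draft slot is still 0. A small stand-alone illustration of that check, assuming only that torch is importable and using made-up tensors:

import torch

# Rows whose draft slots are all the sentinel value 0 are flagged as empty.
draft_tokens = torch.tensor([[0, 0, 0],     # no match found
                             [17, 42, 0],   # partial proposal
                             [5, 5, 5]])    # full proposal
is_empty_draft_tokens = (draft_tokens == 0).all(dim=1)
print(is_empty_draft_tokens)  # tensor([ True, False, False])

The removed TODOs point at the known caveat of this sentinel: a genuine draft consisting entirely of token id 0 would also be flagged as empty, which is why switching to -1 had been floated.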

vllm/v1/worker/gpu_input_batch.py

0 additions, 6 deletions

@@ -114,9 +114,6 @@ def __init__(
             pin_memory=False,
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
-        self.token_ids_gpu_tensor = torch.zeros(
-            max_num_reqs, max_model_len, dtype=torch.int32, device=device
-        )
         self.is_token_ids_tensor = torch.zeros(
             (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False
         )
@@ -127,9 +124,6 @@ def __init__(
         self.req_prompt_embeds: dict[int, torch.Tensor] = {}
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
-        self.num_tokens_no_spec_gpu = torch.zeros(
-            max_num_reqs, dtype=torch.int32, device=device
-        )
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_computed_tokens_cpu_tensor = torch.zeros(
             (max_num_reqs,),
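
The two attributes removed here were dense per-request GPU tensors sized by max_num_reqs (and, for token_ids_gpu_tensor, by max_model_len as well). As a rough, purely illustrative sense of scale, with assumed values that are not taken from this config:

# Back-of-the-envelope sizing for a [max_num_reqs, max_model_len] int32 buffer.
# The numbers below are assumptions for illustration only.
max_num_reqs = 1024
max_model_len = 32768
bytes_per_elem = 4  # int32

size_bytes = max_num_reqs * max_model_len * bytes_per_elem
print(f"{size_bytes / 2**20:.0f} MiB")  # 128 MiB at these assumed sizes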
