Commit cffc19c

FENP, pisceskkk, and zhenwenqi2024 committed

model runner support PCP.

Co-authored-by: QiuChunshuo <[email protected]>
Co-authored-by: zhenwenqi2024 <[email protected]>
Signed-off-by: FENP <[email protected]>

1 parent fa8804a · commit cffc19c

File tree: 2 files changed, +398 −49 lines (only the vllm/v1/worker/gpu_model_runner.py diff is shown below)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 157 additions & 47 deletions
@@ -42,6 +42,7 @@
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
     get_dcp_group,
+    get_pcp_group,
     get_pp_group,
     get_tp_group,
     graph_capture,
@@ -100,7 +101,7 @@
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
-    get_dcp_local_seq_lens,
+    get_cp_local_seq_lens,
     reorder_batch_to_split_decodes_and_prefills,
     split_attn_metadata,
 )
@@ -154,7 +155,7 @@
     UBatchSlices,
     check_ubatch_thresholds,
 )
-from vllm.v1.worker.utils import is_residual_scattered_for_sp
+from vllm.v1.worker.utils import PCPManager, is_residual_scattered_for_sp

 from .utils import (
     AttentionGroup,
@@ -305,7 +306,11 @@ def __init__(
         # Always set to false after the first forward pass
         self.calculate_kv_scales = self.cache_config.calculate_kv_scales
         self.dcp_world_size = self.parallel_config.decode_context_parallel_size
+        self.pcp_world_size = self.parallel_config.prefill_context_parallel_size
+        self.cp_world_size = self.dcp_world_size * self.pcp_world_size
         self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
+        self.pcp_rank = 0 if self.pcp_world_size <= 1 else get_pcp_group().rank_in_group
+        self.cp_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs

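Illustration (not part of the commit): the flattened cp_rank introduced above treats DCP as the fast-moving dimension, so ranks sharing a pcp_rank are contiguous. A minimal sketch of that arithmetic, with the world sizes below chosen as example assumptions:

def flatten_cp_rank(pcp_rank: int, dcp_rank: int, dcp_world_size: int) -> int:
    # DCP is the fast-moving dimension: ranks with the same pcp_rank are adjacent.
    return dcp_world_size * pcp_rank + dcp_rank

if __name__ == "__main__":
    dcp_world_size, pcp_world_size = 2, 2   # assumed example sizes
    for pcp_rank in range(pcp_world_size):
        for dcp_rank in range(dcp_world_size):
            print((pcp_rank, dcp_rank), "->", flatten_cp_rank(pcp_rank, dcp_rank, dcp_world_size))
    # (0,0)->0, (0,1)->1, (1,0)->2, (1,1)->3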
@@ -469,25 +474,38 @@ def __init__(
         # Cache the device properties.
         self._init_device_properties()

+        if self.pcp_world_size > 1:
+            # NOTE For PCP, we will pad the tokens of each request
+            # to a multiple of 2 * pcp_size that is possible greater
+            # than the max_num_batched_tokens.
+            max_buffer_num_tokens = (
+                self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size
+            )
+        else:
+            max_buffer_num_tokens = self.max_num_tokens
+
         # Persistent buffers for CUDA graphs.
-        self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
-        self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
+        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
+        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
         self.query_start_loc = self._make_buffer(
             self.max_num_reqs + 1, dtype=torch.int32
         )
         self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
         self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
-        if self.dcp_world_size > 1:
-            self.dcp_local_seq_lens = self._make_buffer(
+        if self.cp_world_size > 1:
+            self.cp_local_seq_lens = self._make_buffer(
                 self.max_num_reqs, dtype=torch.int32
             )
         # Because inputs_embeds may be bfloat16 and we don't need a numpy
         # version of this tensor, avoid a RuntimeError by not creating a
         # numpy buffer.
         self.inputs_embeds = self._make_buffer(
-            self.max_num_tokens, self.inputs_embeds_size, dtype=self.dtype, numpy=False
+            max_buffer_num_tokens,
+            self.inputs_embeds_size,
+            dtype=self.dtype,
+            numpy=False,
         )
-        self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
+        self.is_token_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.bool)
         self.discard_request_mask = self._make_buffer(
             self.max_num_reqs, dtype=torch.bool
         )
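Illustration (not part of the commit): under PCP each request's scheduled tokens are padded up to a multiple of 2 * pcp_world_size, so every request can add at most 2 * pcp_world_size - 1 pad tokens; the hunk above sizes the persistent buffers with the slightly looser bound of 2 * pcp_world_size extra slots per request. A standalone sketch of that arithmetic, with example numbers that are assumptions rather than values from the source:

def pcp_buffer_size(max_num_tokens: int, max_num_reqs: int, pcp_world_size: int) -> int:
    """Upper bound on buffered tokens after per-request padding to a multiple of 2*pcp."""
    if pcp_world_size <= 1:
        return max_num_tokens
    # Each request gains at most (2 * pcp_world_size - 1) pad tokens; the commit
    # rounds this up to 2 * pcp_world_size per request for simplicity.
    return max_num_tokens + max_num_reqs * 2 * pcp_world_size

# Example with assumed values: 8192 batched tokens, 256 requests, PCP degree 4.
print(pcp_buffer_size(8192, 256, 4))  # 8192 + 256 * 8 = 10240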
@@ -500,7 +518,20 @@ def __init__(

         # Only relevant for multimodal models
         if self.supports_mm_inputs:
-            self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
+            self.is_mm_embed = self._make_buffer(
+                max_buffer_num_tokens, dtype=torch.bool
+            )
+
+        # Manager for Prefill Context Parallism
+        if self.pcp_world_size > 1:
+            self.pcp_manager = PCPManager(
+                self.pcp_world_size,
+                self.pcp_rank,
+                max_buffer_num_tokens,
+                self.max_num_reqs,
+                self.device,
+                self.pin_memory,
+            )

         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -515,7 +546,7 @@ def __init__(
             # 1D-RoPE.
             # See page 5 of https://arxiv.org/abs/2409.12191
             self.mrope_positions = self._make_buffer(
-                (3, self.max_num_tokens + 1), dtype=torch.int64
+                (3, max_buffer_num_tokens + 1), dtype=torch.int64
             )

         # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
@@ -531,7 +562,7 @@ def __init__(
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         # Keep in int64 to avoid overflow with long context
         self.arange_np = np.arange(
-            max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
+            max(self.max_num_reqs + 1, self.max_model_len, max_buffer_num_tokens),
             dtype=np.int64,
         )

@@ -545,7 +576,7 @@ def __init__(
         self.kv_sharing_fast_prefill_logits_indices = None
         if self.cache_config.kv_sharing_fast_prefill:
             self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
-                self.max_num_tokens, dtype=torch.int32, device=self.device
+                max_buffer_num_tokens, dtype=torch.int32, device=self.device
             )

         self.uniform_decode_query_len = 1 + self.num_spec_tokens
@@ -1314,6 +1345,32 @@ def _prepare_inputs(
             out=positions_np,
         )

+        self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
+
+        if self.pcp_world_size > 1:
+            num_scheduled_tokens[:num_reqs], pcp_positions = (
+                self.pcp_manager.update_tokens_for_pcp(
+                    num_scheduled_tokens[:num_reqs],
+                    self.arange_np,
+                    self.input_batch.num_reqs,
+                    self.reorder_batch_threshold,
+                )
+            )
+
+            # Re-update after PCP split sequences.
+            total_num_scheduled_tokens = sum(num_scheduled_tokens)
+            scheduler_output.total_num_scheduled_tokens = total_num_scheduled_tokens
+
+            req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
+            cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
+            positions_np = self.positions.np[:total_num_scheduled_tokens]
+            np.add(
+                self.input_batch.num_computed_tokens_cpu[req_indices],
+                pcp_positions[:total_num_scheduled_tokens],
+                out=positions_np,
+            )
+
         # Calculate M-RoPE positions.
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
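Illustration (not part of the commit): PCPManager.update_tokens_for_pcp lives in the other file of this commit and is not shown here. The sketch below is only an assumed picture of the usual context-parallel prefill split that is consistent with the padding to a multiple of 2 * pcp_world_size above: each request is padded, cut into 2 * pcp chunks, and rank r takes chunk r plus chunk 2*pcp-1-r so causal-attention work stays balanced. The function name and chunk assignment are assumptions, not taken from the diff.

import numpy as np

def zigzag_pcp_split(num_tokens: int, pcp_world_size: int, pcp_rank: int):
    """Assumed sketch: pad to a multiple of 2*pcp, take chunks r and 2*pcp-1-r."""
    chunks = 2 * pcp_world_size
    padded = -(-num_tokens // chunks) * chunks      # round up to a multiple of 2*pcp
    chunk = padded // chunks
    positions = np.arange(padded)
    lo = positions[pcp_rank * chunk:(pcp_rank + 1) * chunk]
    hi = positions[(chunks - 1 - pcp_rank) * chunk:(chunks - pcp_rank) * chunk]
    num_pads = padded - num_tokens                  # pad tokens introduced for this request
    return np.concatenate([lo, hi]), num_pads

local_pos, pads = zigzag_pcp_split(num_tokens=10, pcp_world_size=2, pcp_rank=0)
print(local_pos, pads)  # rank 0 gets chunks 0 and 3 of the padded 12-token request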
@@ -1389,9 +1446,6 @@ def _prepare_inputs(

                 output_idx += num_sched

-        self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
-        self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
-
         # Prepare the attention metadata.
         self.query_start_loc.np[0] = 0
         self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
@@ -1413,9 +1467,16 @@ def _prepare_inputs(

         # Record which requests should not be sampled,
         # so that we could clear the sampled tokens before returning
-        self.discard_request_mask.np[:num_reqs] = (
-            self.seq_lens.np[:num_reqs] < num_tokens_np
-        )
+        if self.pcp_world_size > 1:
+            self.discard_request_mask.np[:num_reqs] = (
+                self.input_batch.num_computed_tokens_cpu[:num_reqs]
+                + num_scheduled_tokens * self.pcp_world_size
+                - self.pcp_manager.num_pcp_pads_cpu[:num_reqs]
+            ) < num_tokens_np
+        else:
+            self.discard_request_mask.np[:num_reqs] = (
+                self.seq_lens.np[:num_reqs] < num_tokens_np
+            )
         self.discard_request_mask.copy_to_gpu(num_reqs)

         # Copy the tensors to the GPU.
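Illustration (not part of the commit): under PCP, num_scheduled_tokens now holds each rank's local padded share, so a request's global progress this step is computed tokens plus local tokens times pcp_world_size minus the pads introduced by the split; the request is still partial (and its sampled token is discarded) when that progress falls short of its full length. A standalone numpy sketch mirroring the expression above, with all array values being assumed examples:

import numpy as np

pcp_world_size = 2
num_computed_tokens = np.array([0, 128])   # tokens already processed per request
local_scheduled = np.array([6, 4])         # this rank's (padded) share this step
num_pcp_pads = np.array([2, 0])            # pad tokens introduced by the PCP split
num_tokens_full = np.array([100, 132])     # full prompt(+output) length per request

global_progress = num_computed_tokens + local_scheduled * pcp_world_size - num_pcp_pads
discard_request_mask = global_progress < num_tokens_full
print(global_progress, discard_request_mask)  # [ 10 136] [ True False]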
@@ -1449,10 +1510,19 @@ def _prepare_inputs(
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
             logits_indices = query_start_loc[1:] - 1
+            if self.pcp_world_size > 1:
+                logits_indices = (
+                    torch.from_numpy(cu_num_tokens) * self.pcp_world_size
+                    - self.pcp_manager.num_pcp_pads_cpu_tensor[:num_reqs]
+                    - 1
+                )
+            else:
+                logits_indices = query_start_loc[1:] - 1
             num_draft_tokens = None
             spec_decode_metadata = None
             num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
         else:
+            assert self.pcp_world_size == 1, "PCP not support spec decode now"
             # Get the number of draft tokens for each request.
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
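Illustration (not part of the commit): the PCP branch above recomputes the per-request sampling positions from the local cumulative token counts rather than from query_start_loc. The standalone sketch below only mirrors that arithmetic with assumed example values; it does not assert anything about the restored token layout beyond what the expression itself states.

import numpy as np
import torch

pcp_world_size = 2
cu_num_tokens = np.array([6, 10])        # cumulative local (padded) tokens per request
num_pcp_pads = torch.tensor([2, 2])      # per-request pad counts from the PCP split

# Mirrors the diff: cu_num_tokens * pcp - pads - 1 gives the sampled-token index.
logits_indices = torch.from_numpy(cu_num_tokens) * pcp_world_size - num_pcp_pads - 1
print(logits_indices)  # tensor([ 9, 17])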
@@ -1516,6 +1586,10 @@ def _build_attention_metadata(
         """
         :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
         """
+        assert num_tokens_padded is None or self.pcp_world_size == 1, (
+            "PCP not support pad attn now"
+        )
+
         num_tokens_padded = num_tokens_padded or num_tokens
         num_reqs_padded = num_reqs_padded or num_reqs

@@ -1528,16 +1602,16 @@ def _build_attention_metadata(
             logits_indices
         )

-        # update seq_lens of decode reqs under DCP.
-        if self.dcp_world_size > 1:
-            self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
+        # update seq_lens of decode reqs under CP.
+        if self.cp_world_size > 1:
+            self.cp_local_seq_lens.cpu[:num_reqs] = get_cp_local_seq_lens(
                 self.seq_lens.cpu[:num_reqs],
-                self.dcp_world_size,
-                self.dcp_rank,
+                self.cp_world_size,
+                self.cp_rank,
                 self.parallel_config.cp_kv_cache_interleave_size,
             )
-            self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
-            self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
+            self.cp_local_seq_lens.cpu[num_reqs:].fill_(0)
+            self.cp_local_seq_lens.copy_to_gpu(num_reqs_padded)

         attn_metadata: PerLayerAttnMetadata = {}
         if ubatch_slices is not None:
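Illustration (not part of the commit): get_cp_local_seq_lens is the renamed DCP helper and its implementation is outside this diff. The sketch below is an assumed picture of how a sequence's KV length could split across cp_world_size ranks when KV is laid out round-robin in chunks of cp_kv_cache_interleave_size tokens; the function name and exact semantics are assumptions used only to illustrate the per-rank local lengths.

import torch

def local_seq_lens_interleaved(seq_lens: torch.Tensor, cp_world_size: int,
                               cp_rank: int, interleave_size: int) -> torch.Tensor:
    """Assumed sketch: tokens are distributed across CP ranks in round-robin
    chunks of `interleave_size`; return how many land on `cp_rank`."""
    round_len = cp_world_size * interleave_size
    full_rounds = seq_lens // round_len
    remainder = seq_lens % round_len
    extra = (remainder - cp_rank * interleave_size).clamp(min=0, max=interleave_size)
    return full_rounds * interleave_size + extra

seq_lens = torch.tensor([10, 64, 7])   # assumed example sequence lengths
print(local_seq_lens_interleaved(seq_lens, cp_world_size=4, cp_rank=1, interleave_size=2))
# tensor([ 2, 16,  2])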
@@ -1567,10 +1641,10 @@ def _build_attention_metadata(
                 :num_reqs_padded
             ]

-        dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
-        if self.dcp_world_size > 1:
-            dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
-            dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded]
+        cp_local_seq_lens, cp_local_seq_lens_cpu = None, None
+        if self.cp_world_size > 1:
+            cp_local_seq_lens = self.cp_local_seq_lens.gpu[:num_reqs_padded]
+            cp_local_seq_lens_cpu = self.cp_local_seq_lens.cpu[:num_reqs_padded]

         spec_decode_common_attn_metadata = None

@@ -1585,11 +1659,18 @@ def _build_attention_metadata(
                 num_reqs_padded,
             )

+            maybe_pcp_full_tokens = (
+                num_tokens_padded
+                if self.pcp_world_size == 1
+                else num_tokens * self.pcp_world_size
+                - sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])
+            )
+
            if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
                 # Encoder-only layers do not have KV cache, so we need to
                 # create a dummy block table and slot mapping for them.
                 blk_table_tensor = torch.zeros(
-                    (num_reqs_padded, 1),
+                    (num_tokens_padded, 1),
                     dtype=torch.int32,
                     device=self.device,
                 )
@@ -1601,12 +1682,26 @@ def _build_attention_metadata(
             else:
                 blk_table = self.input_batch.block_table[kv_cache_gid]
                 blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
-                slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
+                slot_mapping = blk_table.slot_mapping.gpu[:maybe_pcp_full_tokens]

             # Fill unused with -1. Needed for reshape_and_cache in full cuda
             # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
-            slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
-            blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
+            if self.pcp_world_size == 1:
+                slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
+                blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
+
+            if self.pcp_world_size > 1:
+                # After pcp allgather and restore, there are padded tokens in
+                # kv, so we need pad slotmapping for alignment.
+                pcp_padded_slot_mapping = self.pcp_manager.pcp_padded_slot_mapping[
+                    : num_tokens * self.pcp_world_size
+                ]
+                cp_unpad_mask = self.pcp_manager.pcp_unpad_mask_cpu_tensor[
+                    : num_tokens * self.pcp_world_size
+                ]
+                pcp_padded_slot_mapping.fill_(-1)
+                pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping
+                slot_mapping = pcp_padded_slot_mapping

             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=query_start_loc,
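Illustration (not part of the commit): the padded slot mapping above scatters the slots of the real tokens into a buffer laid out like the gathered sequence, leaving pad positions at -1 so the KV-cache write skips them. A minimal standalone sketch with made-up slot numbers and mask:

import torch

# Assumed example: 2 PCP ranks, 4 local (padded) tokens each, so the gathered
# sequence has 8 positions, 2 of which are pads.
gathered_len = 8
slot_mapping = torch.tensor([10, 11, 12, 13, 20, 21])            # slots of the 6 real tokens
unpad_mask = torch.tensor([1, 1, 1, 1, 0, 1, 1, 0], dtype=torch.bool)

padded_slot_mapping = torch.full((gathered_len,), -1, dtype=slot_mapping.dtype)
padded_slot_mapping[unpad_mask] = slot_mapping                   # pads stay at -1
print(padded_slot_mapping)  # tensor([10, 11, 12, 13, -1, 20, 21, -1])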
@@ -1625,8 +1720,13 @@ def _build_attention_metadata(
                 causal=True,
                 encoder_seq_lens=encoder_seq_lens,
                 encoder_seq_lens_cpu=encoder_seq_lens_cpu,
-                dcp_local_seq_lens=dcp_local_seq_lens,
-                dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
+                cp_local_seq_lens=cp_local_seq_lens,
+                cp_local_seq_lens_cpu=cp_local_seq_lens_cpu,
+                pcp_allgather_restore_idx=self.pcp_manager.pcp_allgather_restore_idx.gpu[
+                    : num_tokens * self.pcp_world_size
+                ]
+                if self.pcp_world_size > 1
+                else None,
             )

             if self.speculative_config and spec_decode_common_attn_metadata is None:
@@ -1690,16 +1790,6 @@ def _build_attention_metadata(
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i

-        if spec_decode_common_attn_metadata is not None and (
-            num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
-        ):
-            # Currently the drafter still only uses piecewise cudagraphs (and modifies
-            # the attention metadata in directly), and therefore does not want to use
-            # padded attention metadata.
-            spec_decode_common_attn_metadata = (
-                spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
-            )
-
         return attn_metadata, spec_decode_common_attn_metadata

     def _compute_cascade_attn_prefix_lens(
@@ -2904,6 +2994,9 @@ def execute_model(
             scheduler_output,
             num_scheduled_tokens_np,
         )
+        if self.pcp_world_size > 1:
+            max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
+            num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens

         cascade_attn_prefix_lens = None
         # Disable cascade attention when using microbatching (DBO)
@@ -3011,6 +3104,23 @@ def execute_model(
             hidden_states = model_output
             aux_hidden_states = None

+        if self.pcp_world_size > 1:
+            # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
+            # ignores the padding from CUDA Graph.
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states[:num_tokens_unpadded],
+                0,
+            )
+            restore_idx = self.pcp_manager.pcp_allgather_restore_idx.gpu[
+                : hidden_states.shape[0]
+            ]
+            hidden_states = torch.index_select(
+                hidden_states,
+                0,
+                restore_idx,
+            )
+            # Restore total_num_scheduled_tokens.
+            scheduler_output.total_num_scheduled_tokens = num_scheduled_tokens
         if not self.broadcast_pp_output:
             # Common case.
             if not get_pp_group().is_last_rank:
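Illustration (not part of the commit): the all-gather concatenates each rank's locally ordered tokens, and pcp_allgather_restore_idx is a permutation that puts the concatenation back into the original token order. A standalone simulation with two fake ranks; the gather is simulated with torch.cat rather than a real collective, and all tensors are assumed example data:

import torch

# Original token order 0..7 was split across 2 ranks (e.g. by a zigzag split).
rank0_token_ids = torch.tensor([0, 1, 6, 7])
rank1_token_ids = torch.tensor([2, 3, 4, 5])
hidden_dim = 4

# Per-rank "hidden states": row values encode the original token each row belongs to.
rank0_hidden = rank0_token_ids.float().unsqueeze(1).expand(-1, hidden_dim)
rank1_hidden = rank1_token_ids.float().unsqueeze(1).expand(-1, hidden_dim)

# Simulated all_gather along dim 0 (a real run would use the PCP process group).
gathered = torch.cat([rank0_hidden, rank1_hidden], dim=0)
gathered_order = torch.cat([rank0_token_ids, rank1_token_ids])   # [0,1,6,7,2,3,4,5]

# restore_idx[i] = position in `gathered` that holds original token i.
restore_idx = torch.argsort(gathered_order)
restored = torch.index_select(gathered, 0, restore_idx)
print(restored[:, 0])  # tensor([0., 1., 2., 3., 4., 5., 6., 7.])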
@@ -5274,15 +5384,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             kv_transfer_group.register_kv_caches(kv_caches)
             kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

-        if self.dcp_world_size > 1:
+        if self.cp_world_size > 1:
             layer_type = cast(type[Any], AttentionLayerBase)
             layers = get_layers_from_vllm_config(self.vllm_config, layer_type)
             for layer in layers.values():
                 layer_impl = getattr(layer, "impl", None)
                 if layer_impl is None:
                     continue
                 assert layer_impl.need_to_return_lse_for_decode, (
-                    "DCP requires attention impls to return"
+                    "PCP & DCP require attention impls to return"
                     " the softmax lse for decode, but the impl "
                     f"{layer_impl.__class__.__name__} "
                     "does not return the softmax lse for decode."
