 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
     get_dcp_group,
+    get_pcp_group,
     get_pp_group,
     get_tp_group,
     graph_capture,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
-    get_dcp_local_seq_lens,
+    get_cp_local_seq_lens,
     reorder_batch_to_split_decodes_and_prefills,
     split_attn_metadata,
 )
     UBatchSlices,
     check_ubatch_thresholds,
 )
-from vllm.v1.worker.utils import is_residual_scattered_for_sp
+from vllm.v1.worker.utils import PCPManager, is_residual_scattered_for_sp

 from .utils import (
     AttentionGroup,
@@ -305,7 +306,11 @@ def __init__(
         # Always set to false after the first forward pass
         self.calculate_kv_scales = self.cache_config.calculate_kv_scales
         self.dcp_world_size = self.parallel_config.decode_context_parallel_size
+        self.pcp_world_size = self.parallel_config.prefill_context_parallel_size
+        self.cp_world_size = self.dcp_world_size * self.pcp_world_size
         self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
+        self.pcp_rank = 0 if self.pcp_world_size <= 1 else get_pcp_group().rank_in_group
+        self.cp_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs

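The two added `cp_rank` lines flatten the (pcp_rank, dcp_rank) pair into a single context-parallel rank, with DCP as the fast-varying dimension. A minimal sketch of that layout (illustrative only, not part of the patch):

```python
# Enumerate the flattened cp_rank for every (pcp_rank, dcp_rank) pair,
# mirroring cp_rank = dcp_world_size * pcp_rank + dcp_rank above.
def flatten_cp_ranks(pcp_world_size: int, dcp_world_size: int) -> dict[tuple[int, int], int]:
    return {
        (pcp_rank, dcp_rank): dcp_world_size * pcp_rank + dcp_rank
        for pcp_rank in range(pcp_world_size)
        for dcp_rank in range(dcp_world_size)
    }

# With pcp=2 and dcp=2:
# (0, 0) -> 0, (0, 1) -> 1, (1, 0) -> 2, (1, 1) -> 3
print(flatten_cp_ranks(pcp_world_size=2, dcp_world_size=2))
```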
@@ -469,25 +474,38 @@ def __init__(
         # Cache the device properties.
         self._init_device_properties()

+        if self.pcp_world_size > 1:
+            # NOTE: For PCP, we pad the tokens of each request to a
+            # multiple of 2 * pcp_size, which may be greater than
+            # max_num_batched_tokens.
+            max_buffer_num_tokens = (
+                self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size
+            )
+        else:
+            max_buffer_num_tokens = self.max_num_tokens
+
         # Persistent buffers for CUDA graphs.
-        self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
-        self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
+        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
+        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
         self.query_start_loc = self._make_buffer(
             self.max_num_reqs + 1, dtype=torch.int32
         )
         self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
         self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
-        if self.dcp_world_size > 1:
-            self.dcp_local_seq_lens = self._make_buffer(
+        if self.cp_world_size > 1:
+            self.cp_local_seq_lens = self._make_buffer(
                 self.max_num_reqs, dtype=torch.int32
             )
         # Because inputs_embeds may be bfloat16 and we don't need a numpy
         # version of this tensor, avoid a RuntimeError by not creating a
         # numpy buffer.
         self.inputs_embeds = self._make_buffer(
-            self.max_num_tokens, self.inputs_embeds_size, dtype=self.dtype, numpy=False
+            max_buffer_num_tokens,
+            self.inputs_embeds_size,
+            dtype=self.dtype,
+            numpy=False,
         )
-        self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
+        self.is_token_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.bool)
         self.discard_request_mask = self._make_buffer(
             self.max_num_reqs, dtype=torch.bool
         )
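The `max_buffer_num_tokens` term exists because padding every request to a multiple of `2 * pcp_world_size` adds at most `2 * pcp_world_size - 1` extra tokens per request, so `max_num_reqs * 2 * pcp_world_size` is a safe upper bound on the extra buffer space. A small sketch of the arithmetic (the actual padding rule lives in `PCPManager`, which is not shown in this diff; `padded_len` is a hypothetical helper):

```python
def padded_len(num_tokens: int, pcp_world_size: int) -> int:
    """Round a request's token count up to a multiple of 2 * pcp_world_size."""
    align = 2 * pcp_world_size
    return (num_tokens + align - 1) // align * align

def worst_case_buffer(max_num_tokens: int, max_num_reqs: int, pcp_world_size: int) -> int:
    # Each request gains at most 2 * pcp_world_size - 1 pad tokens, so this
    # bound matches max_buffer_num_tokens in the hunk above.
    return max_num_tokens + max_num_reqs * 2 * pcp_world_size

assert padded_len(13, pcp_world_size=2) == 16  # 13 -> next multiple of 4
assert padded_len(16, pcp_world_size=2) == 16  # already aligned
assert worst_case_buffer(8192, 256, 2) == 9216
```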
@@ -500,7 +518,20 @@ def __init__(

         # Only relevant for multimodal models
         if self.supports_mm_inputs:
-            self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
+            self.is_mm_embed = self._make_buffer(
+                max_buffer_num_tokens, dtype=torch.bool
+            )
+
+        # Manager for Prefill Context Parallelism
+        if self.pcp_world_size > 1:
+            self.pcp_manager = PCPManager(
+                self.pcp_world_size,
+                self.pcp_rank,
+                max_buffer_num_tokens,
+                self.max_num_reqs,
+                self.device,
+                self.pin_memory,
+            )

         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -515,7 +546,7 @@ def __init__(
         # 1D-RoPE.
         # See page 5 of https://arxiv.org/abs/2409.12191
         self.mrope_positions = self._make_buffer(
-            (3, self.max_num_tokens + 1), dtype=torch.int64
+            (3, max_buffer_num_tokens + 1), dtype=torch.int64
         )

         # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
@@ -531,7 +562,7 @@ def __init__(
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         # Keep in int64 to avoid overflow with long context
         self.arange_np = np.arange(
-            max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
+            max(self.max_num_reqs + 1, self.max_model_len, max_buffer_num_tokens),
             dtype=np.int64,
         )

@@ -545,7 +576,7 @@ def __init__(
         self.kv_sharing_fast_prefill_logits_indices = None
         if self.cache_config.kv_sharing_fast_prefill:
             self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
-                self.max_num_tokens, dtype=torch.int32, device=self.device
+                max_buffer_num_tokens, dtype=torch.int32, device=self.device
             )

         self.uniform_decode_query_len = 1 + self.num_spec_tokens
@@ -1314,6 +1345,32 @@ def _prepare_inputs(
             out=positions_np,
         )

+        self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
+
+        if self.pcp_world_size > 1:
+            num_scheduled_tokens[:num_reqs], pcp_positions = (
+                self.pcp_manager.update_tokens_for_pcp(
+                    num_scheduled_tokens[:num_reqs],
+                    self.arange_np,
+                    self.input_batch.num_reqs,
+                    self.reorder_batch_threshold,
+                )
+            )
+
+            # Re-compute these after PCP has split the sequences.
+            total_num_scheduled_tokens = sum(num_scheduled_tokens)
+            scheduler_output.total_num_scheduled_tokens = total_num_scheduled_tokens
+
+            req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
+            cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
+            positions_np = self.positions.np[:total_num_scheduled_tokens]
+            np.add(
+                self.input_batch.num_computed_tokens_cpu[req_indices],
+                pcp_positions[:total_num_scheduled_tokens],
+                out=positions_np,
+            )
+
         # Calculate M-RoPE positions.
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
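`update_tokens_for_pcp` shrinks each request's scheduled-token count to this rank's shard and returns per-token position offsets (`pcp_positions`) used to rebuild `positions_np`; its implementation is not part of this diff. The sketch below assumes a head/tail ("zig-zag") split, a common way to balance causal-attention work when the padded prompt is cut into `2 * pcp_world_size` chunks and rank `r` keeps chunks `r` and `2P - 1 - r`; the real `PCPManager` may differ.

```python
import numpy as np

def zigzag_local_positions(num_tokens: int, pcp_world_size: int, pcp_rank: int) -> np.ndarray:
    """Hypothetical head/tail split: pad to 2*P chunks, keep chunks r and 2P-1-r."""
    align = 2 * pcp_world_size
    padded = (num_tokens + align - 1) // align * align
    chunk = padded // align
    head = np.arange(pcp_rank * chunk, (pcp_rank + 1) * chunk)
    tail_start = (2 * pcp_world_size - 1 - pcp_rank) * chunk
    tail = np.arange(tail_start, tail_start + chunk)
    return np.concatenate([head, tail])  # positions >= num_tokens are padding

# A 10-token prompt on 2 PCP ranks is padded to 12 and cut into 4 chunks of 3:
# rank 0 keeps chunks 0 and 3 -> [0, 1, 2, 9, 10, 11] (10 and 11 are padding),
# rank 1 keeps chunks 1 and 2 -> [3, 4, 5, 6, 7, 8].
print(zigzag_local_positions(10, pcp_world_size=2, pcp_rank=0))
print(zigzag_local_positions(10, pcp_world_size=2, pcp_rank=1))
```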
@@ -1389,9 +1446,6 @@ def _prepare_inputs(

             output_idx += num_sched

-        self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
-        self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
-
         # Prepare the attention metadata.
         self.query_start_loc.np[0] = 0
         self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
@@ -1413,9 +1467,16 @@ def _prepare_inputs(

         # Record which requests should not be sampled,
         # so that we could clear the sampled tokens before returning
-        self.discard_request_mask.np[:num_reqs] = (
-            self.seq_lens.np[:num_reqs] < num_tokens_np
-        )
+        if self.pcp_world_size > 1:
+            self.discard_request_mask.np[:num_reqs] = (
+                self.input_batch.num_computed_tokens_cpu[:num_reqs]
+                + num_scheduled_tokens * self.pcp_world_size
+                - self.pcp_manager.num_pcp_pads_cpu[:num_reqs]
+            ) < num_tokens_np
+        else:
+            self.discard_request_mask.np[:num_reqs] = (
+                self.seq_lens.np[:num_reqs] < num_tokens_np
+            )
         self.discard_request_mask.copy_to_gpu(num_reqs)

         # Copy the tensors to the GPU.
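The PCP branch of `discard_request_mask` recovers the number of prompt tokens the whole PCP group advanced this step from the local count: `local_scheduled * pcp_world_size - pads`. A worked example with made-up numbers (the attribute names follow the hunk above):

```python
import numpy as np

pcp_world_size = 2
num_computed = np.array([0, 64])     # tokens already processed (global)
local_scheduled = np.array([6, 2])   # tokens scheduled on THIS pcp rank
num_pcp_pads = np.array([2, 0])      # alignment pad tokens added this step
prompt_lens = np.array([10, 80])     # num_tokens_np: full prompt lengths

# Global tokens really advanced this step = local * P - pads.
advanced = local_scheduled * pcp_world_size - num_pcp_pads
discard = (num_computed + advanced) < prompt_lens
print(advanced)  # [10  4]
print(discard)   # [False  True] -> request 1 is still a partial prefill
```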
@@ -1449,10 +1510,19 @@ def _prepare_inputs(
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
             logits_indices = query_start_loc[1:] - 1
+            if self.pcp_world_size > 1:
+                logits_indices = (
+                    torch.from_numpy(cu_num_tokens) * self.pcp_world_size
+                    - self.pcp_manager.num_pcp_pads_cpu_tensor[:num_reqs]
+                    - 1
+                )
+            else:
+                logits_indices = query_start_loc[1:] - 1
             num_draft_tokens = None
             spec_decode_metadata = None
             num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
         else:
+            assert self.pcp_world_size == 1, "PCP does not support spec decode yet"
             # Get the number of draft tokens for each request.
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
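Under PCP, `logits_indices` must point at each request's last real token in the all-gathered, restored token stream: `cu_num_tokens * pcp_world_size` is the end of the request's padded global span, and subtracting the pads and 1 lands on the last real token. A numeric sketch with made-up values (it assumes the per-request pads sit at the end of each restored span):

```python
import numpy as np
import torch

pcp_world_size = 2
local_scheduled = np.array([6, 8])   # per-rank token counts after the PCP split
num_pcp_pads = np.array([2, 0])      # pads introduced by the 2*P alignment

cu_num_tokens = np.cumsum(local_scheduled)  # [6, 14]
logits_indices = (
    torch.from_numpy(cu_num_tokens) * pcp_world_size
    - torch.from_numpy(num_pcp_pads)
    - 1
)
# Request 0 spans global tokens [0, 12) with 2 trailing pads -> last real token 9.
# Request 1 spans global tokens [12, 28) with no pads -> last real token 27.
print(logits_indices)  # tensor([ 9, 27])
```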
@@ -1516,6 +1586,10 @@ def _build_attention_metadata(
         """
         :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
         """
+        assert num_tokens_padded is None or self.pcp_world_size == 1, (
+            "PCP does not support padded attention metadata yet"
+        )
+
         num_tokens_padded = num_tokens_padded or num_tokens
         num_reqs_padded = num_reqs_padded or num_reqs

@@ -1528,16 +1602,16 @@ def _build_attention_metadata(
                 logits_indices
             )

-        # update seq_lens of decode reqs under DCP.
-        if self.dcp_world_size > 1:
-            self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
+        # update seq_lens of decode reqs under CP.
+        if self.cp_world_size > 1:
+            self.cp_local_seq_lens.cpu[:num_reqs] = get_cp_local_seq_lens(
                 self.seq_lens.cpu[:num_reqs],
-                self.dcp_world_size,
-                self.dcp_rank,
+                self.cp_world_size,
+                self.cp_rank,
                 self.parallel_config.cp_kv_cache_interleave_size,
             )
-            self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
-            self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
+            self.cp_local_seq_lens.cpu[num_reqs:].fill_(0)
+            self.cp_local_seq_lens.copy_to_gpu(num_reqs_padded)

         attn_metadata: PerLayerAttnMetadata = {}
         if ubatch_slices is not None:
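`get_cp_local_seq_lens` generalizes the old DCP helper: with the KV cache dealt out across `cp_world_size` ranks in groups of `cp_kv_cache_interleave_size` tokens, each rank's local sequence length is the number of KV entries it owns. A rough sketch of that distribution (this is an assumption about the helper's semantics; the real implementation in vLLM may differ):

```python
import torch

def cp_local_seq_lens_sketch(
    seq_lens: torch.Tensor, cp_world_size: int, cp_rank: int, interleave_size: int = 1
) -> torch.Tensor:
    """KV tokens held by `cp_rank` if tokens are assigned round-robin to ranks
    in groups of `interleave_size`."""
    token_idx = torch.arange(int(seq_lens.max()))
    owner = (token_idx // interleave_size) % cp_world_size
    return torch.stack(
        [((token_idx < s) & (owner == cp_rank)).sum() for s in seq_lens]
    ).to(torch.int32)

# 4 cp ranks, interleave size 1: a 10-token sequence splits as [3, 3, 2, 2].
print([cp_local_seq_lens_sketch(torch.tensor([10]), 4, r).item() for r in range(4)])
```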
@@ -1567,10 +1641,10 @@ def _build_attention_metadata(
                 :num_reqs_padded
             ]

-        dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
-        if self.dcp_world_size > 1:
-            dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
-            dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded]
+        cp_local_seq_lens, cp_local_seq_lens_cpu = None, None
+        if self.cp_world_size > 1:
+            cp_local_seq_lens = self.cp_local_seq_lens.gpu[:num_reqs_padded]
+            cp_local_seq_lens_cpu = self.cp_local_seq_lens.cpu[:num_reqs_padded]

         spec_decode_common_attn_metadata = None

@@ -1585,11 +1659,18 @@ def _build_attention_metadata(
                 num_reqs_padded,
             )

+            maybe_pcp_full_tokens = (
+                num_tokens_padded
+                if self.pcp_world_size == 1
+                else num_tokens * self.pcp_world_size
+                - sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])
+            )
+
             if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
                 # Encoder-only layers do not have KV cache, so we need to
                 # create a dummy block table and slot mapping for them.
                 blk_table_tensor = torch.zeros(
-                    (num_reqs_padded, 1),
+                    (num_tokens_padded, 1),
                     dtype=torch.int32,
                     device=self.device,
                 )
@@ -1601,12 +1682,26 @@ def _build_attention_metadata(
             else:
                 blk_table = self.input_batch.block_table[kv_cache_gid]
                 blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
-                slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
+                slot_mapping = blk_table.slot_mapping.gpu[:maybe_pcp_full_tokens]

                 # Fill unused with -1. Needed for reshape_and_cache in full cuda
                 # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
-                slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
-                blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
+                if self.pcp_world_size == 1:
+                    slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
+                    blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
+
+                if self.pcp_world_size > 1:
+                    # After the PCP all-gather and restore, the KV contains
+                    # padded tokens, so the slot mapping must be padded to
+                    # stay aligned with them.
+                    pcp_padded_slot_mapping = self.pcp_manager.pcp_padded_slot_mapping[
+                        : num_tokens * self.pcp_world_size
+                    ]
+                    cp_unpad_mask = self.pcp_manager.pcp_unpad_mask_cpu_tensor[
+                        : num_tokens * self.pcp_world_size
+                    ]
+                    pcp_padded_slot_mapping.fill_(-1)
+                    pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping
+                    slot_mapping = pcp_padded_slot_mapping

             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=query_start_loc,
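The PCP branch scatters the real slot mapping into a buffer sized for the full restored token stream, leaving `-1` at the alignment-pad positions so `reshape_and_cache` writes nothing there. A minimal single-request sketch of that scatter (shapes are illustrative):

```python
import torch

# 12 restored token positions, of which the last 2 are alignment padding.
num_tokens_global = 12
unpad_mask = torch.tensor([True] * 10 + [False] * 2)

# KV-cache slots for the 10 real tokens, in restored (original) order.
real_slot_mapping = torch.arange(100, 110)

padded_slot_mapping = torch.full((num_tokens_global,), -1, dtype=torch.long)
padded_slot_mapping[unpad_mask] = real_slot_mapping
print(padded_slot_mapping)
# tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109,  -1,  -1])
```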
@@ -1625,8 +1720,13 @@ def _build_attention_metadata(
                 causal=True,
                 encoder_seq_lens=encoder_seq_lens,
                 encoder_seq_lens_cpu=encoder_seq_lens_cpu,
-                dcp_local_seq_lens=dcp_local_seq_lens,
-                dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
+                cp_local_seq_lens=cp_local_seq_lens,
+                cp_local_seq_lens_cpu=cp_local_seq_lens_cpu,
+                pcp_allgather_restore_idx=self.pcp_manager.pcp_allgather_restore_idx.gpu[
+                    : num_tokens * self.pcp_world_size
+                ]
+                if self.pcp_world_size > 1
+                else None,
             )

             if self.speculative_config and spec_decode_common_attn_metadata is None:
@@ -1690,16 +1790,6 @@ def _build_attention_metadata(
                 for layer_name in attn_group.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i

-        if spec_decode_common_attn_metadata is not None and (
-            num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
-        ):
-            # Currently the drafter still only uses piecewise cudagraphs (and modifies
-            # the attention metadata in directly), and therefore does not want to use
-            # padded attention metadata.
-            spec_decode_common_attn_metadata = (
-                spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
-            )
-
         return attn_metadata, spec_decode_common_attn_metadata

     def _compute_cascade_attn_prefix_lens(
@@ -2904,6 +2994,9 @@ def execute_model(
             scheduler_output,
             num_scheduled_tokens_np,
         )
+        if self.pcp_world_size > 1:
+            max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
+            num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens

         cascade_attn_prefix_lens = None
         # Disable cascade attention when using microbatching (DBO)
@@ -3011,6 +3104,23 @@ def execute_model(
             hidden_states = model_output
             aux_hidden_states = None

+        if self.pcp_world_size > 1:
+            # NOTE: We must slice hidden_states here because
+            # pcp_allgather_restore_idx ignores the padding added for CUDA graphs.
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states[:num_tokens_unpadded],
+                0,
+            )
+            restore_idx = self.pcp_manager.pcp_allgather_restore_idx.gpu[
+                : hidden_states.shape[0]
+            ]
+            hidden_states = torch.index_select(
+                hidden_states,
+                0,
+                restore_idx,
+            )
+            # Restore total_num_scheduled_tokens.
+            scheduler_output.total_num_scheduled_tokens = num_scheduled_tokens
         if not self.broadcast_pp_output:
             # Common case.
             if not get_pp_group().is_last_rank:
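Each PCP rank only produces hidden states for its own token shard, so the hunk above all-gathers them along dim 0 (rank-major concatenation) and permutes them back into the original token order with `pcp_allgather_restore_idx`. A single-process sketch that simulates the gather with `torch.cat` and derives the restore index with `argsort` (the patch precomputes it in `PCPManager`):

```python
import torch

# 2 pcp ranks, one 8-token request, zig-zag split:
# rank 0 holds global tokens [0, 1, 6, 7], rank 1 holds [2, 3, 4, 5].
hidden_dim = 4
rank_tokens = [torch.tensor([0, 1, 6, 7]), torch.tensor([2, 3, 4, 5])]
full_hidden = torch.arange(8).unsqueeze(1).repeat(1, hidden_dim).float()
per_rank_hidden = [full_hidden[idx] for idx in rank_tokens]

# all_gather along dim 0 concatenates rank-major: [rank0 tokens | rank1 tokens].
gathered = torch.cat(per_rank_hidden, dim=0)
gathered_order = torch.cat(rank_tokens)       # [0, 1, 6, 7, 2, 3, 4, 5]
restore_idx = torch.argsort(gathered_order)   # permutation back to 0..7
restored = torch.index_select(gathered, 0, restore_idx)
assert torch.equal(restored, full_hidden)
```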
@@ -5274,15 +5384,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             kv_transfer_group.register_kv_caches(kv_caches)
             kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

-        if self.dcp_world_size > 1:
+        if self.cp_world_size > 1:
             layer_type = cast(type[Any], AttentionLayerBase)
             layers = get_layers_from_vllm_config(self.vllm_config, layer_type)
             for layer in layers.values():
                 layer_impl = getattr(layer, "impl", None)
                 if layer_impl is None:
                     continue
                 assert layer_impl.need_to_return_lse_for_decode, (
-                    "DCP requires attention impls to return"
+                    "PCP & DCP require attention impls to return"
                     " the softmax lse for decode, but the impl "
                     f"{layer_impl.__class__.__name__}"
                     "does not return the softmax lse for decode."
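This assert exists because under context parallelism each rank attends over only its KV shard; the per-shard outputs can only be combined exactly if the attention impl also returns the softmax log-sum-exp. A minimal sketch of that standard merge (not the code vLLM uses):

```python
import torch

def merge_attn_states(outs: list[torch.Tensor], lses: list[torch.Tensor]) -> torch.Tensor:
    """Combine per-shard attention outputs using their softmax log-sum-exp.
    outs[i]: [num_tokens, head_dim], lses[i]: [num_tokens]."""
    lse_stack = torch.stack(lses)                   # [num_shards, num_tokens]
    global_lse = torch.logsumexp(lse_stack, dim=0)  # [num_tokens]
    weights = torch.exp(lse_stack - global_lse)     # per-shard softmax weights
    return (weights.unsqueeze(-1) * torch.stack(outs)).sum(dim=0)

# Sanity check: splitting the KV into two shards and merging matches full attention.
torch.manual_seed(0)
q = torch.randn(3, 8)
k, v = torch.randn(16, 8), torch.randn(16, 8)

def attn(q, k, v):
    scores = q @ k.T / 8**0.5
    return torch.softmax(scores, -1) @ v, torch.logsumexp(scores, -1)

o1, lse1 = attn(q, k[:8], v[:8])
o2, lse2 = attn(q, k[8:], v[8:])
full, _ = attn(q, k, v)
assert torch.allclose(merge_attn_states([o1, o2], [lse1, lse2]), full, atol=1e-5)
```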