Skip to content

Commit c4d35e2

Browse files
committed
Fix cp variable naming & Support prefill context parallel with MLA backend.
Signed-off-by: FENP <[email protected]>
1 parent cffc19c commit c4d35e2

File tree

15 files changed

+439
-144
lines changed

15 files changed

+439
-144
lines changed

tests/distributed/test_context_parallel.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class ParallelSetup(NamedTuple):
3131
tp_size: int
3232
pp_size: int
3333
dcp_size: int
34+
pcp_size: int
3435
cp_kv_cache_interleave_size: int
3536
eager_mode: bool
3637
chunked_prefill: bool
@@ -55,6 +56,7 @@ def detailed(
5556
tp_base: int = 4,
5657
pp_base: int = 1,
5758
dcp_base: int = 1,
59+
pcp_base: int = 1,
5860
cp_kv_cache_interleave_size: int = 1,
5961
multi_node_only: bool = False,
6062
runner: RunnerOption = "auto",
@@ -70,7 +72,8 @@ def detailed(
7072
ParallelSetup(
7173
tp_size=tp_base,
7274
pp_size=pp_multiplier * pp_base,
73-
dcp_size=int(dcp_multiplier * tp_base),
75+
dcp_size=max(1, int(dcp_multiplier * tp_base)),
76+
pcp_size=pcp_base,
7477
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
7578
eager_mode=eager_mode_val,
7679
chunked_prefill=chunked_prefill_val,
@@ -116,6 +119,7 @@ def _compare_cp_with_tp(
116119
tp_size,
117120
pp_size,
118121
dcp_size,
122+
pcp_size,
119123
cp_kv_cache_interleave_size,
120124
eager_mode,
121125
chunked_prefill,
@@ -196,7 +200,9 @@ def _compare_cp_with_tp(
196200
str(pp_size),
197201
"--decode-context-parallel-size",
198202
str(dcp_size),
199-
"--dcp-kv-cache-interleave-size",
203+
"--prefill-context-parallel-size",
204+
str(pcp_size),
205+
"--cp-kv-cache-interleave-size",
200206
str(cp_kv_cache_interleave_size),
201207
"--distributed-executor-backend",
202208
distributed_backend,
@@ -228,6 +234,8 @@ def _compare_cp_with_tp(
228234
CPTestSettings.detailed(),
229235
CPTestSettings.detailed(tp_base=2),
230236
CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
237+
CPTestSettings.detailed(tp_base=1, pcp_base=2),
238+
CPTestSettings.detailed(tp_base=1, pcp_base=2, cp_kv_cache_interleave_size=64),
231239
],
232240
"bigcode/gpt_bigcode-santacoder": [
233241
CPTestSettings.detailed(),

vllm/attention/backends/abstract.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,8 @@ class AttentionImpl(ABC, Generic[T]):
295295
pcp_world_size: int
296296
pcp_rank: int
297297

298-
total_cp_world_size: int
299-
total_cp_rank: int
298+
cp_world_size: int
299+
cp_rank: int
300300

301301
def __new__(cls, *args, **kwargs):
302302
# use __new__ so that all subclasses will call this
@@ -318,11 +318,11 @@ def __new__(cls, *args, **kwargs):
318318
except AssertionError:
319319
self.pcp_world_size = 1
320320
self.pcp_rank = 0
321-
self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
322-
self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
321+
self.cp_world_size = self.pcp_world_size * self.dcp_world_size
322+
self.cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
323323

324324
self.need_to_return_lse_for_decode = (
325-
self.dcp_world_size > 1 and self.can_return_lse_for_decode
325+
self.cp_world_size > 1 and self.can_return_lse_for_decode
326326
)
327327
return self
328328

vllm/config/parallel.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,12 @@ class is dynamically inherited by the worker class. This is used to inject
237237
"""
238238
cp_kv_cache_interleave_size: int = 1
239239
"""Interleave size of kv_cache storage while using DCP or PCP.
240-
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
241-
and `total_cp_world_size = pcp_world_size * dcp_world_size`.
242-
store interleave_size tokens on total_cp_rank i,
243-
then store next interleave_size tokens on total_cp_rank i+1.
240+
For `cp_rank = pcp_rank * dcp_world_size + dcp_rank`
241+
and `cp_world_size = pcp_world_size * dcp_world_size`,
242+
store interleave_size tokens on cp_rank i,
243+
then store the next interleave_size tokens on cp_rank i+1.
244244
Interleave_size=1: token-level alignment, where token `i` is stored on
245-
total_cp_rank `i % total_cp_world_size`.
245+
cp_rank `i % cp_world_size`.
246246
Interleave_size=block_size: block-level alignment, where tokens are
247247
first populated to the preceding ranks. Tokens are then stored
248248
in (rank i+1, block j) only after (rank i, block j) is fully occupied.
@@ -312,11 +312,6 @@ def _validate_parallel_config(self) -> Self:
312312
"num_redundant_experts."
313313
)
314314

315-
if self.prefill_context_parallel_size > 1:
316-
raise ValueError(
317-
"Prefill context parallelism is not fully supported. "
318-
"Please set prefill_context_parallel_size to 1."
319-
)
320315
return self
321316

322317
@property

vllm/distributed/parallel_state.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,9 +1085,6 @@ def get_dcp_group() -> GroupCoordinator:
10851085
return _DCP
10861086

10871087

1088-
# kept for backward compatibility
1089-
get_context_model_parallel_group = get_dcp_group
1090-
10911088
_PP: GroupCoordinator | None = None
10921089

10931090

vllm/engine/arg_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,16 +1898,6 @@ def _set_default_chunked_prefill_and_prefix_caching_args(
18981898
default_chunked_prefill = model_config.is_chunked_prefill_supported
18991899
default_prefix_caching = model_config.is_prefix_caching_supported
19001900

1901-
if self.prefill_context_parallel_size > 1:
1902-
default_chunked_prefill = False
1903-
default_prefix_caching = False
1904-
logger.warning_once(
1905-
"--prefill-context-parallel-size > 1 is not compatible with "
1906-
"chunked prefill and prefix caching now. Chunked prefill "
1907-
"and prefix caching have been disabled by default.",
1908-
scope="local",
1909-
)
1910-
19111901
if self.enable_chunked_prefill is None:
19121902
self.enable_chunked_prefill = default_chunked_prefill
19131903

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,10 @@ def tp_size(self):
894894
def dp_size(self):
895895
return self.moe_parallel_config.dp_size
896896

897+
@property
898+
def pcp_size(self):
899+
return self.moe_parallel_config.pcp_size
900+
897901
@property
898902
def ep_size(self):
899903
return self.moe_parallel_config.ep_size
@@ -906,6 +910,10 @@ def tp_rank(self):
906910
def dp_rank(self):
907911
return self.moe_parallel_config.dp_rank
908912

913+
@property
914+
def pcp_rank(self):
915+
return self.moe_parallel_config.pcp_rank
916+
909917
@property
910918
def ep_rank(self):
911919
return self.moe_parallel_config.ep_rank

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
AttentionCGSupport,
4646
AttentionMetadataBuilder,
4747
CommonAttentionMetadata,
48-
get_dcp_local_seq_lens,
48+
get_cp_local_seq_lens,
4949
get_kv_cache_layout,
5050
)
5151
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -401,7 +401,7 @@ def schedule(
401401
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
402402
dcp_context_kv_lens = seq_lens - query_kv_lens
403403

404-
dcp_context_kv_lens = get_dcp_local_seq_lens(
404+
dcp_context_kv_lens = get_cp_local_seq_lens(
405405
dcp_context_kv_lens,
406406
self.dcp_world_size,
407407
self.dcp_rank,

vllm/v1/attention/backends/flashinfer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
AttentionMetadataBuilder,
5454
CommonAttentionMetadata,
5555
KVCacheLayoutType,
56-
get_dcp_local_seq_lens,
56+
get_cp_local_seq_lens,
5757
get_kv_cache_layout,
5858
get_per_layer_parameters,
5959
infer_global_hyperparameters,
@@ -694,7 +694,7 @@ def build(
694694
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
695695
)
696696

697-
seq_lens_cpu = get_dcp_local_seq_lens(
697+
seq_lens_cpu = get_cp_local_seq_lens(
698698
seq_lens_cpu,
699699
self.dcp_world_size,
700700
self.dcp_rank,

Comments (0)