Skip to content

Commit 92f0085

Browse files
author
Jingchun Gao
committed
fix num_q_head && add UT
Signed-off-by: Jingchun Gao <[email protected]>
1 parent d9d342d commit 92f0085

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

tests/distributed/test_context_parallel.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -232,6 +232,10 @@ def _compare_cp_with_tp(
232 232   "bigcode/gpt_bigcode-santacoder": [
233 233   CPTestSettings.detailed(),
234 234   CPTestSettings.detailed(tp_base=2),
    235 + CPTestSettings.detailed(attn_backend="FLASHINFER"),
    236 + CPTestSettings.detailed(
    237 + attn_backend="FLASHINFER", cp_kv_cache_interleave_size=16
    238 + ),
235 239   ],
236 240   }
237 241

vllm/v1/attention/backends/flashinfer.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -478,9 +478,8 @@ def __init__(
478 478   self.dcp_rank = 0
479 479   self.dcp_kv_cache_interleave_size = 1
480 480
481     - self.num_qo_heads = (
482     - self.model_config.get_num_attention_heads(self.vllm_config.parallel_config)
483     - * self.dcp_world_size
    481 + self.num_qo_heads = self.model_config.get_num_attention_heads(
    482 + self.vllm_config.parallel_config
484 483   )
485 484
486 485   self.num_kv_heads = self.kv_cache_spec.num_kv_heads

0 commit comments

Comments
 (0)