
Commit d9d342d

[Performance][MLA][ROCm] Remove redundant D2D copy in deepseek (#27457)
Signed-off-by: ganyi <[email protected]>
1 parent 53d7f1f commit d9d342d

5 files changed: 49 additions & 41 deletions

csrc/attention/merge_attn_states.cu

Lines changed: 12 additions & 15 deletions
@@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
     scalar_t* output, float* output_lse, const scalar_t* prefix_output,
     const float* prefix_lse, const scalar_t* suffix_output,
     const float* suffix_lse, const uint num_tokens, const uint num_heads,
-    const uint head_size) {
+    const uint head_size, const uint prefix_head_stride,
+    const uint output_head_stride) {
   using pack_128b_t = uint4;
   const uint pack_size = 16 / sizeof(scalar_t);
   const uint threads_per_head = head_size / pack_size;
@@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
   const uint head_idx = token_head_idx % num_heads;
 
   const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
-  const uint head_offset =
-      token_idx * num_heads * head_size + head_idx * head_size;
-  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
-  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
-  scalar_t* output_head_ptr = output + head_offset;
+  const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
+                               head_idx * prefix_head_stride;
+  const uint dst_head_offset = token_idx * num_heads * output_head_stride +
+                               head_idx * output_head_stride;
+  const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
+  const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
+  scalar_t* output_head_ptr = output + dst_head_offset;
 
   float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
   float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
@@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
       reinterpret_cast<float*>(prefix_lse.data_ptr()),                   \
       reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),             \
       reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,       \
-      num_heads, head_size);                                             \
+      num_heads, head_size, prefix_head_stride, output_head_stride);     \
   }
 
 /*@brief Merges the attention states from prefix and suffix
@@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
   const uint num_tokens = output.size(0);
   const uint num_heads = output.size(1);
   const uint head_size = output.size(2);
+  const uint prefix_head_stride = prefix_output.stride(1);
+  const uint output_head_stride = output.stride(1);
   const uint pack_size = 16 / sizeof(scalar_t);
   TORCH_CHECK(head_size % pack_size == 0,
               "headsize must be multiple of pack_size:", pack_size);
-  TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
-              "output heads must be contiguous in memory");
-  TORCH_CHECK(
-      prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
-      "prefix_output heads must be contiguous in memory");
-  TORCH_CHECK(
-      suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
-      "suffix_output heads must be contiguous in memory");
   float* output_lse_ptr = nullptr;
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
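
The key change in the kernel is that per-head offsets are now derived from explicit head strides instead of assuming heads are packed `head_size` apart, which is what lets the launcher drop the contiguity checks and accept padded or sliced views. A minimal Python sketch of that offset arithmetic (the helper name and standalone form are illustrative, not part of the kernel):

    # Element offsets for a given (token_idx, head_idx) when the source and
    # destination tensors may carry different head strides (e.g. a padded source).
    def head_offsets(token_idx, head_idx, num_heads,
                     prefix_head_stride, output_head_stride):
        src_head_offset = (token_idx * num_heads + head_idx) * prefix_head_stride
        dst_head_offset = (token_idx * num_heads + head_idx) * output_head_stride
        return src_head_offset, dst_head_offset

This mirrors the kernel's `src_head_offset`/`dst_head_offset` expressions; it assumes `stride(0) == num_heads * stride(1)`, which holds when only the last dimension is padded.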

csrc/ops.h

Lines changed: 1 addition & 2 deletions
@@ -52,14 +52,13 @@ void paged_attention_v2(
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);
 
-#ifndef USE_ROCM
 void merge_attn_states(torch::Tensor& output,
                        std::optional<torch::Tensor> output_lse,
                        const torch::Tensor& prefix_output,
                        const torch::Tensor& prefix_lse,
                        const torch::Tensor& suffix_output,
                        const torch::Tensor& suffix_lse);
-
+#ifndef USE_ROCM
 void convert_vertical_slash_indexes(
     torch::Tensor& block_count,   // [BATCH, N_HEADS, NUM_ROWS]
     torch::Tensor& block_offset,  // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]

csrc/torch_bindings.cpp

Lines changed: 1 addition & 2 deletions
@@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
 
-#ifndef USE_ROCM
   // Merge attn states
   // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
   // can be used to combine partial attention results (in the split-KV case)
@@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor suffix_output,"
      " Tensor suffix_lse) -> ()");
   ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
-
+#ifndef USE_ROCM
   ops.def(
       "convert_vertical_slash_indexes("
       " Tensor! block_count, Tensor! block_offset, "

vllm/attention/ops/triton_merge_attn_states.py

Lines changed: 17 additions & 6 deletions
@@ -20,7 +20,11 @@ def merge_attn_states(
     num_query_heads = output.shape[1]
     head_size = output.shape[2]
     padded_head_size = triton.next_power_of_2(head_size)
-
+    # The head stride of `output` is not necessarily the same as that of
+    # `prefix_output` and `suffix_output`, since those may be padded by the
+    # attention backend.
+    prefix_head_stride = prefix_output.stride(1)
+    output_head_stride = output.stride(1)
     # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
     merge_attn_states_kernel[(num_tokens, num_query_heads)](
         output,
@@ -29,6 +33,8 @@ def merge_attn_states(
         prefix_lse,
         suffix_output,
         suffix_lse,
+        prefix_head_stride,
+        output_head_stride,
         head_size,
         padded_head_size,
         output_lse is not None,
@@ -43,6 +49,8 @@ def merge_attn_states_kernel(
     prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
     suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
     suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
+    prefix_head_stride,
+    output_head_stride,
     HEAD_SIZE: tl.constexpr,
     PADDED_HEAD_SIZE: tl.constexpr,
     OUTPUT_LSE: tl.constexpr,
@@ -79,15 +87,15 @@ def merge_attn_states_kernel(
     head_mask = head_arange < HEAD_SIZE
     p_out = tl.load(
         prefix_output
-        + token_idx * num_heads * HEAD_SIZE
-        + head_idx * HEAD_SIZE
+        + token_idx * num_heads * prefix_head_stride
+        + head_idx * prefix_head_stride
         + head_arange,
         mask=head_mask,
     )
     s_out = tl.load(
         suffix_output
-        + token_idx * num_heads * HEAD_SIZE
-        + head_idx * HEAD_SIZE
+        + token_idx * num_heads * prefix_head_stride
+        + head_idx * prefix_head_stride
         + head_arange,
         mask=head_mask,
     )
@@ -99,7 +107,10 @@ def merge_attn_states_kernel(
     s_scale = s_se / out_se
     out = p_out * p_scale + s_out * s_scale
     tl.store(
-        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        output
+        + token_idx * num_heads * output_head_stride
+        + head_idx * output_head_stride
+        + head_arange,
         out,
         mask=head_mask,
     )
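
For intuition on why `stride(1)` is passed explicitly: a tensor that the attention backend pads along the head dimension and then slices back keeps the stride of its padded allocation, so its head stride no longer equals `HEAD_SIZE`. A small PyTorch illustration with made-up sizes:

    import torch

    num_tokens, num_heads, head_size, padded = 4, 16, 128, 256

    # Backend allocates a padded buffer, then hands back an unpadded view.
    suffix_padded = torch.randn(num_tokens, num_heads, padded)
    suffix_output = suffix_padded[..., :head_size]  # view; head stride stays 256
    output = torch.empty(num_tokens, num_heads, head_size)

    print(suffix_output.stride(1), output.stride(1))  # 256 vs. 128

Loading and storing with the real strides lets the kernels consume such views directly instead of requiring a contiguous copy first.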

vllm/v1/attention/backends/mla/common.py

Lines changed: 18 additions & 16 deletions
@@ -1238,15 +1238,13 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
     def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+
         if self.is_aiter_triton_fp8_bmm_enabled:
+            out = out.view(-1, self.num_heads, self.v_head_dim)
             # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
             x = rocm_aiter_ops.triton_fp8_bmm(
-                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
+                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
             )
-            # Convert from (B, N, V) to (B, N * V)
-            x = x.reshape(-1, self.num_heads * self.v_head_dim)
-            # Copy result
-            out.copy_(x)
         else:
             # Convert from (B, N * V) to (N, B, V)
             out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
@@ -1824,7 +1822,8 @@ def _forward_prefill(
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
-    ) -> torch.Tensor:
+        output: torch.Tensor,
+    ) -> None:
         # TODO (zyongye): Prefill function here
         assert attn_metadata.prefill is not None
         assert self.dcp_world_size is not None
@@ -1837,7 +1836,7 @@ def _forward_prefill(
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        output = self._run_prefill_new_tokens(
+        output_prefill = self._run_prefill_new_tokens(
             prefill=attn_metadata.prefill,
             q=q,
             k=k,
@@ -1846,7 +1845,7 @@ def _forward_prefill(
         )
 
         if has_context:
-            suffix_output, suffix_lse = output
+            suffix_output, suffix_lse = output_prefill
             if self.dcp_world_size > 1:
                 context_output, context_lse = (
                     self._context_parallel_compute_prefill_context(
@@ -1862,20 +1861,22 @@ def _forward_prefill(
                     q, kv_c_and_k_pe_cache, attn_metadata, k_scale
                 )
 
-            output = torch.empty_like(suffix_output)
+            # unpad if necessary
+            if self._pad_v:
+                context_output = context_output[..., : v.shape[-1]]
+                suffix_output = suffix_output[..., : v.shape[-1]]
+
+            output = output.view(-1, self.num_heads, self.v_head_dim)
             merge_attn_states(
                 output=output,
                 prefix_output=context_output,
                 prefix_lse=context_lse,
                 suffix_output=suffix_output,
                 suffix_lse=suffix_lse,
             )
-
-            # unpad if necessary
-            if self._pad_v:
-                output = output[..., : v.shape[-1]]
-
-            return output.flatten(start_dim=-2)
+        else:
+            output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
+            output.copy_(output_prefill)
 
     @abstractmethod
     def _forward_decode(
@@ -1970,13 +1971,14 @@ def forward(
             kv_cache = kv_cache.view(current_platform.fp8_dtype())
 
         if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
+            self._forward_prefill(
                 prefill_q,
                 prefill_k_c_normed,
                 prefill_k_pe,
                 kv_cache,
                 attn_metadata,
                 layer._k_scale,
+                output=output[num_decode_tokens:],
            )
 
         if has_decode:
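
The MLA-side change applies the same idea end to end: `_forward_prefill` now receives `output[num_decode_tokens:]` and writes into it (via `merge_attn_states` or `copy_`), and `_v_up_proj` passes `YQ=out` so the fp8 BMM writes its result in place, instead of producing a temporary that the caller then copies into the output buffer. A minimal sketch of the before/after pattern, with illustrative names rather than the vLLM code:

    import torch

    def compute(q, out=None):
        # Stand-in for the prefill / up-projection math; writes into `out` if given.
        if out is None:
            out = torch.empty_like(q)
        return torch.add(q, 1.0, out=out)

    buf = torch.zeros(8, 4)   # full output buffer (decode + prefill tokens)
    q = torch.randn(6, 4)     # prefill tokens

    # Before: materialize a temporary, then copy it into the buffer slice
    # (a device-to-device copy when these tensors live on the GPU).
    buf[2:] = compute(q)

    # After: pass the slice itself (a view of `buf`) as the destination;
    # the result lands in `buf` with no extra copy.
    compute(q, out=buf[2:])

Because the merge kernel now tolerates strided heads, the unpadded views of `context_output` and `suffix_output` can be fed to it directly and the final result written straight into the output slice, so the temporary and the trailing copy disappear.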
