@@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
     scalar_t* output, float* output_lse, const scalar_t* prefix_output,
     const float* prefix_lse, const scalar_t* suffix_output,
     const float* suffix_lse, const uint num_tokens, const uint num_heads,
-    const uint head_size) {
+    const uint head_size, const uint prefix_head_stride,
+    const uint output_head_stride) {
   using pack_128b_t = uint4;
   const uint pack_size = 16 / sizeof(scalar_t);
   const uint threads_per_head = head_size / pack_size;
@@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
   const uint head_idx = token_head_idx % num_heads;
 
   const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
-  const uint head_offset =
-      token_idx * num_heads * head_size + head_idx * head_size;
-  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
-  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
-  scalar_t* output_head_ptr = output + head_offset;
+  const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
+                               head_idx * prefix_head_stride;
+  const uint dst_head_offset = token_idx * num_heads * output_head_stride +
+                               head_idx * output_head_stride;
+  const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
+  const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
+  scalar_t* output_head_ptr = output + dst_head_offset;
 
   float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
   float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
@@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
         reinterpret_cast<float*>(prefix_lse.data_ptr()), \
         reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
         reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
-        num_heads, head_size); \
+        num_heads, head_size, prefix_head_stride, output_head_stride); \
   }
 
 /* @brief Merges the attention states from prefix and suffix
@@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
   const uint num_tokens = output.size(0);
   const uint num_heads = output.size(1);
   const uint head_size = output.size(2);
+  const uint prefix_head_stride = prefix_output.stride(1);
+  const uint output_head_stride = output.stride(1);
   const uint pack_size = 16 / sizeof(scalar_t);
   TORCH_CHECK(head_size % pack_size == 0,
               "headsize must be multiple of pack_size:", pack_size);
-  TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
-              "output heads must be contiguous in memory");
-  TORCH_CHECK(
-      prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
-      "prefix_output heads must be contiguous in memory");
-  TORCH_CHECK(
-      suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
-      "suffix_output heads must be contiguous in memory");
   float* output_lse_ptr = nullptr;
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
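For reviewers, a minimal scalar sketch of the merge this kernel performs per (token, head) with the new stride parameters. This is an illustration only: the function name and loop structure are not part of this change, the -inf clamping mirrors what the kernel is assumed to do, and the real kernel vectorizes the inner loop with 128-bit packed loads and one thread group per head.

#include <algorithm>
#include <cmath>
#include <limits>

// Illustrative host-side reference (not part of this diff): merges prefix and
// suffix attention outputs with a numerically stable log-sum-exp weighting.
// prefix_head_stride / output_head_stride mirror the new kernel parameters;
// the suffix tensor is addressed with the prefix stride, as in the kernel.
template <typename scalar_t>
void merge_attn_states_reference(
    scalar_t* output, float* output_lse, const scalar_t* prefix_output,
    const float* prefix_lse, const scalar_t* suffix_output,
    const float* suffix_lse, const unsigned num_tokens,
    const unsigned num_heads, const unsigned head_size,
    const unsigned prefix_head_stride, const unsigned output_head_stride) {
  for (unsigned token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (unsigned head_idx = 0; head_idx < num_heads; ++head_idx) {
      // LSE tensors are laid out [num_heads, num_tokens], as in the kernel.
      float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
      float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
      // Clamp -inf (fully masked side) so exp() below stays finite (assumed
      // to match the kernel's handling).
      p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::max() : p_lse;
      s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::max() : s_lse;
      const float max_lse = std::max(p_lse, s_lse);
      const float p_scale = std::exp(p_lse - max_lse);
      const float s_scale = std::exp(s_lse - max_lse);
      const float inv_sum = 1.0f / (p_scale + s_scale);

      // Head offsets follow the strided addressing introduced by this diff.
      const scalar_t* prefix_head_ptr =
          prefix_output + (token_idx * num_heads + head_idx) * prefix_head_stride;
      const scalar_t* suffix_head_ptr =
          suffix_output + (token_idx * num_heads + head_idx) * prefix_head_stride;
      scalar_t* output_head_ptr =
          output + (token_idx * num_heads + head_idx) * output_head_stride;

      for (unsigned i = 0; i < head_size; ++i) {
        const float merged = (static_cast<float>(prefix_head_ptr[i]) * p_scale +
                              static_cast<float>(suffix_head_ptr[i]) * s_scale) *
                             inv_sum;
        output_head_ptr[i] = static_cast<scalar_t>(merged);
      }
      if (output_lse != nullptr) {
        output_lse[head_idx * num_tokens + token_idx] =
            max_lse + std::log(p_scale + s_scale);
      }
    }
  }
}

Design note: the launcher now reads prefix_output.stride(1) and output.stride(1) and forwards them to the kernel instead of requiring stride(-2) == head_size, presumably so the op can read from and write into non-contiguous views (e.g. a slice along the head dimension) without an extra copy. Since both inputs are addressed with prefix_head_stride, prefix_output and suffix_output are expected to share a layout.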