Skip to content

Commit e811cb5

Browse files
committed
fix some comments
Signed-off-by: liumengge1205 <[email protected]>
1 parent af52c00 commit e811cb5

File tree

2 files changed

+17
-34
lines changed

2 files changed

+17
-34
lines changed

vllm/model_executor/layers/fla/ops/fused_recurrent.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
4141
scale,
4242
N: tl.int64, # num of sequences
4343
T: tl.int64, # num of tokens
44-
NP2_T: tl.constexpr,
4544
B: tl.constexpr,
4645
H: tl.constexpr,
4746
HV: tl.constexpr,
@@ -122,23 +121,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
122121
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
123122
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
124123

125-
if IS_EAGLE_TREE:
126-
token_indices = tl.arange(0, NP2_T)
127-
mask_retrieve = token_indices < T
128-
retrieve_parent_token_base = (
129-
retrieve_parent_token
130-
+ (i_n * stride_indices_seq)
131-
+ token_indices * stride_indices_tok
132-
)
133-
parent_idx_tokens = tl.load(retrieve_parent_token_base, mask_retrieve)
134-
135124
for i_t in range(0, T):
136125
# i_t = 0 should use the b_h from USE_INITIAL_STATE
137126
if IS_EAGLE_TREE: # noqa: SIM102
138127
if i_t != 0:
139128
# when calculating current step's attention, load the state from the parent token
140-
parent_step_idx = tl.sum(
141-
tl.where(token_indices == i_t, parent_idx_tokens, 0)
129+
parent_step_idx = tl.load(
130+
retrieve_parent_token
131+
+ (i_n * stride_indices_seq)
132+
+ i_t * stride_indices_tok
142133
)
143134
p_h0 = (
144135
ht
@@ -242,7 +233,13 @@ def fused_recurrent_gated_delta_rule_fwd(
242233
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
243234
else:
244235
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
245-
NP2_T = triton.next_power_of_2(stride_indices_seq)
236+
237+
if retrieve_parent_token is not None:
238+
assert retrieve_parent_token.stride() == (
239+
stride_indices_seq,
240+
stride_indices_tok,
241+
), "retrieve_parent_token and ssm_state_indices must have the same stride"
242+
246243
grid = (NK, NV, N * HV)
247244
fused_recurrent_gated_delta_rule_fwd_kernel[grid](
248245
q=q,
@@ -260,7 +257,6 @@ def fused_recurrent_gated_delta_rule_fwd(
260257
scale=scale,
261258
N=N,
262259
T=T,
263-
NP2_T=NP2_T,
264260
B=B,
265261
H=H,
266262
HV=HV,

vllm/model_executor/layers/mamba/ops/causal_conv1d.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,6 @@ def _causal_conv1d_update_kernel(
790790
IS_APC_ENABLED: tl.constexpr,
791791
IS_SPEC_DECODING: tl.constexpr,
792792
NP2_STATELEN: tl.constexpr,
793-
NP2_SEQLEN: tl.constexpr,
794793
USE_PAD_SLOT: tl.constexpr,
795794
BLOCK_N: tl.constexpr,
796795
IS_EAGLE_TREE: tl.constexpr,
@@ -973,16 +972,6 @@ def _causal_conv1d_update_kernel(
973972
x_base_1d = x_base # starting of chunk [BLOCK_N]
974973
mask_x_1d = idx_feats < dim
975974

976-
if IS_EAGLE_TREE:
977-
token_indices = tl.arange(0, NP2_SEQLEN)
978-
mask_retrieve = token_indices < seqlen
979-
retrieve_parent_token_base = (
980-
retrieve_parent_token_ptr
981-
+ (idx_seq * stride_retrieve_parent_token_seq)
982-
+ token_indices * stride_retrieve_parent_token_token
983-
)
984-
parent_idx_tokens = tl.load(retrieve_parent_token_base, mask_retrieve)
985-
986975
# STEP 5: compute each token
987976
for idx_token in tl.range(seqlen):
988977
acc = acc_preload
@@ -995,7 +984,6 @@ def _causal_conv1d_update_kernel(
995984
for j in tl.static_range(KERNEL_WIDTH):
996985
if KERNEL_WIDTH == 2:
997986
matrix_w = w_col1 if j == 0 else w_col0
998-
999987
elif KERNEL_WIDTH == 3:
1000988
if j == 0:
1001989
matrix_w = w_col2
@@ -1017,11 +1005,12 @@ def _causal_conv1d_update_kernel(
10171005

10181006
# move to parent for next iteration
10191007
if _idx_token > 0:
1020-
_idx_token = tl.sum(
1021-
tl.where(idx_tokens == _idx_token, parent_idx_tokens, 0).to(
1022-
tl.int64
1023-
)
1024-
)
1008+
_idx_token = tl.load(
1009+
retrieve_parent_token_ptr
1010+
+ idx_seq * stride_retrieve_parent_token_seq
1011+
+ _idx_token * stride_retrieve_parent_token_token,
1012+
mask=_idx_token < seqlen,
1013+
).to(tl.int64)
10251014
x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N]
10261015
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
10271016
else:
@@ -1253,7 +1242,6 @@ def causal_conv1d_update(
12531242
else:
12541243
state_len = width - 1
12551244
np2_statelen = triton.next_power_of_2(state_len)
1256-
np2_seqlen = triton.next_power_of_2(seqlen)
12571245

12581246
# prepare retrieve_parent_token buffer strides if provided
12591247
if retrieve_parent_token is not None:
@@ -1314,7 +1302,6 @@ def grid(META):
13141302
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
13151303
IS_SPEC_DECODING=num_accepted_tokens is not None,
13161304
NP2_STATELEN=np2_statelen,
1317-
NP2_SEQLEN=np2_seqlen,
13181305
USE_PAD_SLOT=pad_slot_id is not None,
13191306
BLOCK_N=256,
13201307
IS_EAGLE_TREE=retrieve_parent_token is not None,

0 commit comments

Comments
 (0)