diff --git a/tests/v1/spec_decode/test_causal_conv1d_with_eagle_tree.py b/tests/v1/spec_decode/test_causal_conv1d_with_eagle_tree.py
new file mode 100644
index 000000000000..df942feefded
--- /dev/null
+++ b/tests/v1/spec_decode/test_causal_conv1d_with_eagle_tree.py
@@ -0,0 +1,212 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
+from vllm.platforms import current_platform
+
+
+def causal_conv1d_update_ref(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    activation: str | None,
+    conv_state_indices: torch.Tensor,
+    num_accepted_tokens: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    retrieve_parent_token: torch.Tensor | None,
+):
+    """
+    x: (dim, seqlen)
+    conv_state: (chunk_size, dim, state_len), where state_len >= width - 1
+    weight: (dim, width)
+    bias: (dim,)
+    out: (dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    unsqueeze = x.dim() == 2
+    dim, width = weight.shape
+    state_len = width - 1
+    output = torch.zeros_like(x)
+    conv_state_ori = conv_state.clone()
+    num_reqs = query_start_loc.shape[0] - 1
+    for j in range(num_reqs):
+        # update conv_state
+        con_seq_len = query_start_loc[j + 1] - query_start_loc[j]
+        x_new = torch.cat(
+            [
+                conv_state[
+                    conv_state_indices[j], :, : num_accepted_tokens[j] + state_len - 1
+                ],
+                x[:, query_start_loc[j] : query_start_loc[j + 1]],
+            ],
+            dim=-1,
+        ).to(weight.dtype)
+        update_state_len = state_len + con_seq_len - 1
+        conv_state[conv_state_indices[j], :, :update_state_len].copy_(
+            x_new[:, num_accepted_tokens[j] :]
+        )
+        # update output
+        for i in range(con_seq_len):
+            con_x = x[:, query_start_loc[j] + i : query_start_loc[j] + i + 1]
+            con_index = i
+            if retrieve_parent_token is not None:
+                while retrieve_parent_token[j, con_index] != -1:
+                    con_index = retrieve_parent_token[j, con_index]
+                    con_x = torch.cat(
+                        [
+                            x[
+                                :,
+                                query_start_loc[j] + con_index : query_start_loc[j]
+                                + con_index
+                                + 1,
+                            ],
+                            con_x,
+                        ],
+                        dim=-1,
+                    )
+            else:
+                con_x = x[:, query_start_loc[j] : query_start_loc[j] + i + 1]
+            con_x = torch.cat(
+                [
+                    conv_state_ori[
+                        conv_state_indices[j],
+                        :,
+                        : num_accepted_tokens[j] + state_len - 1,
+                    ],
+                    con_x,
+                ],
+                dim=-1,
+            ).to(weight.dtype)
+
+            con_x = con_x[:, -width:]
+
+            if unsqueeze:
+                con_x = con_x.unsqueeze(0)
+            out = F.conv1d(con_x, weight.unsqueeze(1), bias, padding=0, groups=dim)[
+                :, :, -1:
+            ]
+            if unsqueeze:
+                out = out.squeeze(0)
+            output[:, query_start_loc[j] + i : query_start_loc[j] + i + 1] = out
+
+    return (output if activation is None else F.silu(output)).to(dtype=dtype_in)
+
+
+@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("silu_activation", [False, True])
+@pytest.mark.parametrize("has_bias", [False, True])
+@pytest.mark.parametrize("is_eagle_tree", [False, True])
+def test_causal_conv1d_update(
+    has_bias: bool, silu_activation: bool, itype: torch.dtype, is_eagle_tree: bool
+):
+    device = "cuda"
+    # set seed
+    current_platform.seed_everything(0)
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (1e-2, 5e-2)
+    # num_reqs = 4, max_spec_len = 5
+    batch_size = 4
+    num_speculative_tokens = 5
+    chunk_size = 64
+    dim = 2048
+    width = 4
+    # shape is [batch_size, max_spec_len + 1]
+    retrieve_parent_token = torch.tensor(
+        [
+            # Tree1:
+            #     0
+            #    / \
+            #   1   2
+            #  /
+            # 3
+            [-1, 0, 0, 1, -1, -1],
+            # Tree2:
+            #     0
+            #    /
+            #   1
+            #  /
+            # 2
+            [-1, 0, 1, -1, -1, -1],
+            # Tree3:
+            #     0
+            #    / \
+            #   1   2
+            #  / \
+            # 3   4
+            [-1, 0, 0, 1, 1, -1],
+            # Tree4:
+            #      0
+            #    /   \
+            #   1     2
+            #  / \   /
+            # 3   4 5
+            [-1, 0, 0, 1, 1, 2],
+        ],
+        device="cuda",
+        dtype=torch.int32,
+    )
+    spec_query_start_loc = torch.tensor(
+        [0, 4, 7, 12, 18],
+        device="cuda",
+        dtype=torch.int32,
+    )
+    spec_state_indices_tensor = torch.tensor(
+        [
+            [1, 2, 3, 4, 5, 6],
+            [7, 8, 9, 10, 11, 12],
+            [13, 14, 15, 16, 17, 18],
+            [19, 20, 21, 22, 23, 24],
+        ],
+        device=device,
+        dtype=torch.int32,
+    )
+    num_accepted_tokens = torch.tensor(
+        [1, 2, 1, 2],
+        device="cuda",
+        dtype=torch.int32,
+    )
+    seqlen = spec_query_start_loc[-1].item()
+    x = torch.rand(seqlen, dim, device=device, dtype=itype)
+    x_ref = x.clone().transpose(0, 1)
+
+    conv_state = torch.randn(
+        chunk_size, dim, width + num_speculative_tokens - 1, device=device, dtype=itype
+    )
+    weight = torch.randn(dim, width, device=device, dtype=itype)
+    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
+    conv_state_ref = conv_state.detach().clone()
+    activation = None if not silu_activation else "silu"
+
+    out = causal_conv1d_update(
+        x,
+        conv_state,
+        weight,
+        bias,
+        activation=activation,
+        conv_state_indices=spec_state_indices_tensor[:, 0][:batch_size],
+        num_accepted_tokens=num_accepted_tokens,
+        query_start_loc=spec_query_start_loc,
+        max_query_len=spec_state_indices_tensor.size(-1),
+        retrieve_parent_token=retrieve_parent_token if is_eagle_tree else None,
+        validate_data=False,
+    )
+    out_ref = causal_conv1d_update_ref(
+        x_ref,
+        conv_state_ref,
+        weight,
+        bias,
+        activation=activation,
+        conv_state_indices=spec_state_indices_tensor[:, 0][:batch_size],
+        num_accepted_tokens=num_accepted_tokens,
+        query_start_loc=spec_query_start_loc,
+        retrieve_parent_token=retrieve_parent_token if is_eagle_tree else None,
+    ).transpose(0, 1)
+    assert torch.equal(conv_state, conv_state_ref)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
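
Note (reviewer sketch, not part of the patch): `retrieve_parent_token` stores, per request, the parent position of every draft token, with -1 marking the tree root (and padding). The reference implementation above rebuilds each token's convolution window by walking this array; a minimal standalone version of that walk, with a hypothetical helper name:

def ancestor_chain(parent_row: list[int], token_idx: int) -> list[int]:
    """Collect token_idx and its ancestors, oldest first (root is index 0)."""
    chain = [token_idx]
    while parent_row[token_idx] != -1:
        token_idx = parent_row[token_idx]
        chain.append(token_idx)
    return chain[::-1]

# Tree1 above: token 3 convolves over tokens [0, 1, 3], not flat positions [1, 2, 3]
assert ancestor_chain([-1, 0, 0, 1, -1, -1], 3) == [0, 1, 3]
assert ancestor_chain([-1, 0, 0, 1, -1, -1], 2) == [0, 2]

This is exactly what the `while retrieve_parent_token[j, con_index] != -1` loop in the reference does, with the cached conv state prepended before the last `width` columns are taken.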
diff --git a/tests/v1/spec_decode/test_recurrent_gated_delta_with_eagle_tree.py b/tests/v1/spec_decode/test_recurrent_gated_delta_with_eagle_tree.py
new file mode 100644
index 000000000000..2a83437b92b9
--- /dev/null
+++ b/tests/v1/spec_decode/test_recurrent_gated_delta_with_eagle_tree.py
@@ -0,0 +1,164 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+import torch.nn.functional as F
+from einops import repeat
+
+from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule
+
+
+def recurrent_gated_delta_rule_ref(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    g: torch.Tensor,
+    cu_seq_lens: torch.Tensor,
+    ssm_state_indices: torch.Tensor,
+    num_accepted_tokens: torch.Tensor,
+    retrieve_parent_token: torch.Tensor | None = None,
+    initial_state: torch.Tensor | None = None,
+    scale: float | None = None,
+):
+    o = torch.zeros(*v.shape).to(v)
+    q, k, v, beta, g = map(
+        lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g]
+    )
+
+    if scale is None:
+        scale = 1 / (q.shape[-1] ** 0.5)
+    q = q * scale
+
+    num_reqs = cu_seq_lens.shape[0] - 1
+    for j in range(num_reqs):
+        h = initial_state[ssm_state_indices[j][num_accepted_tokens[j] - 1]]
+        T = cu_seq_lens[j + 1] - cu_seq_lens[j]
+        for i in range(T):
+            if retrieve_parent_token is not None and i != 0:
+                h = initial_state[ssm_state_indices[j][retrieve_parent_token[j][i]]]
+            b_q = q[:, :, cu_seq_lens[j] + i]
+            b_k = k[:, :, cu_seq_lens[j] + i]
+            b_v = v[:, :, cu_seq_lens[j] + i].clone()
+            h = h.clone() * g[:, :, cu_seq_lens[j] + i].exp()[..., None, None]
+            b_beta = beta[:, :, cu_seq_lens[j] + i]
+            b_v = b_v - (h.clone() * b_k[..., None]).sum(-2)
+            b_v = b_v * b_beta[..., None]
+            h = h.clone() + b_k.unsqueeze(-1) * b_v.unsqueeze(-2)
+            o[:, cu_seq_lens[j] + i, :] = torch.einsum("bhd,bhdm->bhm", b_q, h)
+            initial_state[ssm_state_indices[j][i]] = h
+    return o, initial_state
+
+
+@pytest.mark.parametrize("has_eagle_tree_state", [False, True])
+def test_fused_recurrent(has_eagle_tree_state: bool):
+    torch.manual_seed(42)
+    H = 4
+    K = 128
+    V = 128
+    HV = 8
+
+    # shape is [batch_size, max_spec_len + 1]
+    retrieve_parent_token = torch.tensor(
+        [
+            # Tree1:
+            #     0
+            #    / \
+            #   1   2
+            #  /
+            # 3
+            [-1, 0, 0, 1, -1, -1],
+            # Tree2:
+            #     0
+            #    /
+            #   1
+            #  /
+            # 2
+            [-1, 0, 1, -1, -1, -1],
+            # Tree3:
+            #     0
+            #    / \
+            #   1   2
+            #  / \
+            # 3   4
+            [-1, 0, 0, 1, 1, -1],
+            # Tree4:
+            #      0
+            #    /   \
+            #   1     2
+            #  / \   /
+            # 3   4 5
+            [-1, 0, 0, 1, 1, 2],
+        ],
+        device="cuda",
+        dtype=torch.int32,
+    )
+    # num_reqs = 4, max_spec_len = 5
+    cu_seq_lens = torch.tensor(
+        [0, 4, 7, 12, 18],
+        device="cuda",
+        dtype=torch.int32,
+    )
+    spec_state_indices_tensor = torch.tensor(
+        [
+            [1, 2, 3, 4, 5, 6],
+            [7, 8, 9, 10, 11, 12],
+            [13, 14, 15, 16, 17, 18],
+            [19, 20, 21, 22, 23, 24],
+        ],
+        device="cuda",
+        dtype=torch.int32,
+    )
+    num_accepted_tokens = torch.tensor(
+        [2, 1, 1, 2],
+        device="cuda",
+        dtype=torch.int32,
+    )
+    # for variable-length inputs,
+    # the batch size `B` is expected to be 1 and `cu_seqlens` is required
+    B = 1
+    T = cu_seq_lens.max()
+    chunk_size = 64
+    q = torch.randn(B, T, H, K, dtype=torch.float16)
+    k = torch.randn(B, T, H, K, dtype=torch.float16)
+    v = torch.randn(B, T, HV, V, dtype=torch.float16)
+    beta = torch.randn(B, T, HV, dtype=torch.float16).sigmoid()
+    g = F.logsigmoid(torch.rand(B, T, HV, dtype=torch.float32))
+    h0 = torch.randn(chunk_size, HV, K, V, dtype=torch.float32)
+    q, k, v, beta, g, h0 = map(
+        lambda x: x.to("cuda").requires_grad_(), (q, k, v, beta, g, h0)
+    )
+
+    ref, ref_ht = recurrent_gated_delta_rule_ref(
+        q=F.normalize(
+            repeat(q.clone(), "b t h d -> b t (h g) d", g=HV // H), p=2, dim=-1
+        ),
+        k=F.normalize(
+            repeat(k.clone(), "b t h d -> b t (h g) d", g=HV // H), p=2, dim=-1
+        ),
+        v=v.clone(),
+        beta=beta.clone(),
+        g=g.clone(),
+        cu_seq_lens=cu_seq_lens,
+        ssm_state_indices=spec_state_indices_tensor,
+        num_accepted_tokens=num_accepted_tokens,
+        retrieve_parent_token=retrieve_parent_token if has_eagle_tree_state else None,
+        initial_state=h0.clone(),
+    )
+    tri, tri_ht = fused_recurrent_gated_delta_rule(
+        q=q.clone(),
+        k=k.clone(),
+        v=v.clone(),
+        beta=beta.clone(),
+        g=g.clone(),
+        initial_state=h0.clone(),
+        use_qk_l2norm_in_kernel=True,
+        cu_seqlens=cu_seq_lens,
+        ssm_state_indices=spec_state_indices_tensor,
+        num_accepted_tokens=num_accepted_tokens,
+        retrieve_parent_token=retrieve_parent_token if has_eagle_tree_state else None,
+        inplace_final_state=True,
+    )
+
+    assert torch.allclose(ref, tri, atol=1e-3, rtol=1e-4)
+    assert torch.allclose(ref_ht, tri_ht, atol=1e-3, rtol=1e-4)
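
Note (reviewer sketch, not part of the patch): the reference above differs from linear decoding in one place only — at step i > 0 the recurrent state is reloaded from the slot the parent token wrote, instead of being carried over from step i - 1. A toy scalar analogue, with an invented name, that isolates this routing:

def run_tree_recurrence(inputs, parents, decay=0.9):
    # states[i] holds the state written after consuming token i
    states: dict[int, float] = {}
    for i, x in enumerate(inputs):
        prev = 0.0 if parents[i] == -1 else states[parents[i]]
        states[i] = decay * prev + x  # h_i = decay * h_parent + x_i
    return states

s = run_tree_recurrence([1.0, 2.0, 3.0, 4.0], [-1, 0, 0, 1])
# siblings 1 and 2 both branch from token 0's state; token 2 never sees token 1
assert s[1] == 0.9 * s[0] + 2.0
assert s[2] == 0.9 * s[0] + 3.0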
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None, "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + "IS_EAGLE_TREE": lambda args: args["retrieve_parent_token"] is not None, } ) @triton.jit(do_not_specialize=["N", "T"]) @@ -36,6 +37,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( cu_seqlens, ssm_state_indices, num_accepted_tokens, + retrieve_parent_token, scale, N: tl.int64, # num of sequences T: tl.int64, # num of tokens @@ -58,6 +60,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( IS_CONTINUOUS_BATCHING: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, IS_KDA: tl.constexpr, + IS_EAGLE_TREE: tl.constexpr, ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_hv = i_nh // HV, i_nh % HV @@ -119,6 +122,25 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) for i_t in range(0, T): + # i_t = 0 should use the b_h from USE_INITIAL_STATE + if IS_EAGLE_TREE: # noqa: SIM102 + if i_t != 0: + # when calculating current step's attention, load the state from the parent token + parent_step_idx = tl.load( + retrieve_parent_token + + (i_n * stride_indices_seq) + + i_t * stride_indices_tok + ) + p_h0 = ( + ht + + tl.load( + ssm_state_indices + i_n * stride_indices_seq + parent_step_idx + ).to(tl.int64) + * stride_init_state_token + ) + + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) @@ -184,6 +206,7 @@ def fused_recurrent_gated_delta_rule_fwd( cu_seqlens: torch.LongTensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, + retrieve_parent_token: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] @@ -211,6 +234,12 @@ def fused_recurrent_gated_delta_rule_fwd( else: stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + if retrieve_parent_token is not None: + assert retrieve_parent_token.stride() == ( + stride_indices_seq, + stride_indices_tok, + ), "retrieve_parent_token and ssm_state_indices must have the same stride" + grid = (NK, NV, N * HV) fused_recurrent_gated_delta_rule_fwd_kernel[grid]( q=q, @@ -224,6 +253,7 @@ def fused_recurrent_gated_delta_rule_fwd( cu_seqlens=cu_seqlens, ssm_state_indices=ssm_state_indices, num_accepted_tokens=num_accepted_tokens, + retrieve_parent_token=retrieve_parent_token, scale=scale, N=N, T=T, @@ -264,6 +294,7 @@ def forward( cu_seqlens: torch.LongTensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, + retrieve_parent_token: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ): o, final_state = fused_recurrent_gated_delta_rule_fwd( @@ -278,6 +309,7 @@ def forward( cu_seqlens=cu_seqlens, ssm_state_indices=ssm_state_indices, num_accepted_tokens=num_accepted_tokens, + retrieve_parent_token=retrieve_parent_token, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, ) @@ -296,6 +328,7 @@ def fused_recurrent_gated_delta_rule( cu_seqlens: torch.LongTensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, + retrieve_parent_token: torch.Tensor | 
None = None, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: r""" @@ -385,6 +418,7 @@ def fused_recurrent_gated_delta_rule( cu_seqlens, ssm_state_indices, num_accepted_tokens, + retrieve_parent_token, use_qk_l2norm_in_kernel, ) return o, final_state diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 83c2c5f11e18..ac74557cefaa 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -757,6 +757,7 @@ def _causal_conv1d_update_kernel( query_start_loc_ptr, # (batch + 1) block_idx_last_scheduled_token, # (batch,) initial_state_idx, # (batch,) + retrieve_parent_token_ptr, # (batch, max_spec_len + 1) o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -774,6 +775,8 @@ def _causal_conv1d_update_kernel( stride_conv_state_dim: tl.constexpr, stride_conv_state_tok: tl.constexpr, stride_state_indices: tl.constexpr, + stride_retrieve_parent_token_seq: tl.constexpr, + stride_retrieve_parent_token_token: tl.constexpr, stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, @@ -789,6 +792,7 @@ def _causal_conv1d_update_kernel( NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, BLOCK_N: tl.constexpr, + IS_EAGLE_TREE: tl.constexpr, ): # ruff: noqa: E501 idx_seq = tl.program_id(0) @@ -972,87 +976,140 @@ def _causal_conv1d_update_kernel( for idx_token in tl.range(seqlen): acc = acc_preload - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + if IS_EAGLE_TREE: + _idx_token = idx_token + x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + # convolution operation: itself * wcol[-1] + parent * wcol[-2] + grand-parent * wcol[-3] + ... + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + matrix_w = w_col1 if j == 0 else w_col0 + elif KERNEL_WIDTH == 3: + if j == 0: + matrix_w = w_col2 + elif j == 1: + matrix_w = w_col1 + else: + matrix_w = w_col0 + elif KERNEL_WIDTH == 4: + if j == 0: + matrix_w = w_col3 + elif j == 1: + matrix_w = w_col2 + elif j == 2: + matrix_w = w_col1 + else: + matrix_w = w_col0 + + acc += matrix_x * matrix_w + + # move to parent for next iteration + if _idx_token > 0: + _idx_token = tl.load( + retrieve_parent_token_ptr + + idx_seq * stride_retrieve_parent_token_seq + + _idx_token * stride_retrieve_parent_token_token, + mask=_idx_token < seqlen, + ).to(tl.int64) + x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + else: + # no parent within the current chunk, load from prev conv state: col[-1] (idx 0's parent), col[-2] (idx 0's grand parent), ... 
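
Note (reviewer sketch, not part of the patch): the new kernel block amounts to a two-level indirection — tree step -> parent step -> cache slot. A plain-PyTorch reading of that pointer arithmetic, with tensor indexing standing in for the raw pointer offsets and an invented helper name (ht is assumed to have shape (num_slots, HV, K, V), as in the test's h0):

import torch

def load_parent_state(ht, ssm_state_indices, retrieve_parent_token, i_n, i_t, i_hv):
    # which step within the tree produced this token's parent state
    parent_step_idx = retrieve_parent_token[i_n, i_t]
    # which cache line that step was written to
    slot = ssm_state_indices[i_n, parent_step_idx]
    # the (K, V) state block this program resumes from
    return ht[slot, i_hv]

ht = torch.zeros(25, 8, 4, 4)
idx = torch.tensor([[1, 2, 3, 4, 5, 6]])
par = torch.tensor([[-1, 0, 0, 1, -1, -1]])
assert load_parent_state(ht, idx, par, 0, 3, 0).shape == (4, 4)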
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index 83c2c5f11e18..ac74557cefaa 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -757,6 +757,7 @@ def _causal_conv1d_update_kernel(
     query_start_loc_ptr,  # (batch + 1)
     block_idx_last_scheduled_token,  # (batch,)
     initial_state_idx,  # (batch,)
+    retrieve_parent_token_ptr,  # (batch, max_spec_len + 1)
     o_ptr,  # (batch, dim, seqlen)
     # Matrix dimensions
     batch: int,
@@ -774,6 +775,8 @@ def _causal_conv1d_update_kernel(
     stride_conv_state_dim: tl.constexpr,
     stride_conv_state_tok: tl.constexpr,
     stride_state_indices: tl.constexpr,
+    stride_retrieve_parent_token_seq: tl.constexpr,
+    stride_retrieve_parent_token_token: tl.constexpr,
     stride_o_seq: tl.constexpr,
     stride_o_dim: tl.constexpr,
     stride_o_token: tl.constexpr,
@@ -789,6 +792,7 @@ def _causal_conv1d_update_kernel(
     NP2_STATELEN: tl.constexpr,
     USE_PAD_SLOT: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    IS_EAGLE_TREE: tl.constexpr,
 ):
     # ruff: noqa: E501
     idx_seq = tl.program_id(0)
@@ -972,87 +976,140 @@ def _causal_conv1d_update_kernel(
     for idx_token in tl.range(seqlen):
         acc = acc_preload
 
-        matrix_w = w_col0
-        matrix_x = col0
-        for j in tl.static_range(KERNEL_WIDTH):
-            if KERNEL_WIDTH == 2:
-                if j == 1:  # KERNEL_WIDTH-1:
-                    matrix_w = w_col1
-                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
-                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
-            elif KERNEL_WIDTH == 3:
-                if j == 1:
-                    matrix_w = w_col1
-                    matrix_x = col1
-                elif j == 2:
-                    matrix_w = w_col2
-                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
-                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
-            elif KERNEL_WIDTH == 4:
-                if j == 1:
-                    matrix_w = w_col1
-                    matrix_x = col1
-                elif j == 2:
-                    matrix_w = w_col2
-                    matrix_x = col2
-                elif j == 3:
-                    matrix_w = w_col3
-                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
-                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
-            elif KERNEL_WIDTH == 5:
-                if j == 1:
-                    matrix_w = w_col1
-                    matrix_x = col1
-                elif j == 2:
-                    matrix_w = w_col2
-                    matrix_x = col2
-                elif j == 3:
-                    matrix_w = w_col3
-                    matrix_x = col3
-                elif j == 4:
-                    matrix_w = w_col4
-                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
-                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
-            elif KERNEL_WIDTH == 6:
-                if j == 1:
-                    matrix_w = w_col1
-                    matrix_x = col1
-                elif j == 2:
-                    matrix_w = w_col2
-                    matrix_x = col2
-                elif j == 3:
-                    matrix_w = w_col3
-                    matrix_x = col3
-                elif j == 4:
-                    matrix_w = w_col4
-                    matrix_x = col4
-                elif j == 5:
-                    matrix_w = w_col5
-                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
-                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
-
-            acc += matrix_x * matrix_w  # [BLOCK_N]
-
-        if KERNEL_WIDTH == 2:
-            col0 = matrix_x
-        elif KERNEL_WIDTH == 3:
-            col0 = col1
-            col1 = matrix_x
-        elif KERNEL_WIDTH == 4:
-            col0 = col1
-            col1 = col2
-            col2 = matrix_x
-        elif KERNEL_WIDTH == 5:
-            col0 = col1
-            col1 = col2
-            col2 = col3
-            col3 = matrix_x
-        elif KERNEL_WIDTH == 6:
-            col0 = col1
-            col1 = col2
-            col2 = col3
-            col3 = col4
-            col4 = matrix_x
+        if IS_EAGLE_TREE:
+            _idx_token = idx_token
+            x_ptrs_1d = x_base_1d + _idx_token * stride_x_token  # [BLOCK_N]
+            matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            # convolution over the tree path:
+            # itself * wcol[-1] + parent * wcol[-2] + grandparent * wcol[-3] + ...
+            for j in tl.static_range(KERNEL_WIDTH):
+                if KERNEL_WIDTH == 2:
+                    matrix_w = w_col1 if j == 0 else w_col0
+                elif KERNEL_WIDTH == 3:
+                    if j == 0:
+                        matrix_w = w_col2
+                    elif j == 1:
+                        matrix_w = w_col1
+                    else:
+                        matrix_w = w_col0
+                elif KERNEL_WIDTH == 4:
+                    if j == 0:
+                        matrix_w = w_col3
+                    elif j == 1:
+                        matrix_w = w_col2
+                    elif j == 2:
+                        matrix_w = w_col1
+                    else:
+                        matrix_w = w_col0

+                acc += matrix_x * matrix_w
+
+                # move to the parent token for the next iteration
+                if _idx_token > 0:
+                    _idx_token = tl.load(
+                        retrieve_parent_token_ptr
+                        + idx_seq * stride_retrieve_parent_token_seq
+                        + _idx_token * stride_retrieve_parent_token_token,
+                        mask=_idx_token < seqlen,
+                    ).to(tl.int64)
+                    x_ptrs_1d = x_base_1d + _idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+                else:
+                    # no parent within the current chunk; load from the previous conv
+                    # state: col[-1] (idx 0's parent), col[-2] (idx 0's grandparent), ...
+                    if KERNEL_WIDTH == 2:
+                        if _idx_token == 0:
+                            matrix_x = col0
+                    elif KERNEL_WIDTH == 3:
+                        matrix_x = col1 if _idx_token == 0 else col0
+                    elif KERNEL_WIDTH == 4:
+                        if _idx_token == 0:
+                            matrix_x = col2
+                        elif _idx_token == -1:
+                            matrix_x = col1
+                        else:
+                            matrix_x = col0
+                    _idx_token = _idx_token - 1
+        else:
+            matrix_w = w_col0
+            matrix_x = col0
+            for j in tl.static_range(KERNEL_WIDTH):
+                if KERNEL_WIDTH == 2:
+                    if j == 1:  # KERNEL_WIDTH-1:
+                        matrix_w = w_col1
+                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                        matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+                elif KERNEL_WIDTH == 3:
+                    if j == 1:
+                        matrix_w = w_col1
+                        matrix_x = col1
+                    elif j == 2:
+                        matrix_w = w_col2
+                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                        matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+                elif KERNEL_WIDTH == 4:
+                    if j == 1:
+                        matrix_w = w_col1
+                        matrix_x = col1
+                    elif j == 2:
+                        matrix_w = w_col2
+                        matrix_x = col2
+                    elif j == 3:
+                        matrix_w = w_col3
+                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                        matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+                elif KERNEL_WIDTH == 5:
+                    if j == 1:
+                        matrix_w = w_col1
+                        matrix_x = col1
+                    elif j == 2:
+                        matrix_w = w_col2
+                        matrix_x = col2
+                    elif j == 3:
+                        matrix_w = w_col3
+                        matrix_x = col3
+                    elif j == 4:
+                        matrix_w = w_col4
+                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                        matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+                elif KERNEL_WIDTH == 6:
+                    if j == 1:
+                        matrix_w = w_col1
+                        matrix_x = col1
+                    elif j == 2:
+                        matrix_w = w_col2
+                        matrix_x = col2
+                    elif j == 3:
+                        matrix_w = w_col3
+                        matrix_x = col3
+                    elif j == 4:
+                        matrix_w = w_col4
+                        matrix_x = col4
+                    elif j == 5:
+                        matrix_w = w_col5
+                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                        matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+
+                acc += matrix_x * matrix_w  # [BLOCK_N]
+
+            if KERNEL_WIDTH == 2:
+                col0 = matrix_x
+            elif KERNEL_WIDTH == 3:
+                col0 = col1
+                col1 = matrix_x
+            elif KERNEL_WIDTH == 4:
+                col0 = col1
+                col1 = col2
+                col2 = matrix_x
+            elif KERNEL_WIDTH == 5:
+                col0 = col1
+                col1 = col2
+                col2 = col3
+                col3 = matrix_x
+            elif KERNEL_WIDTH == 6:
+                col0 = col1
+                col1 = col2
+                col2 = col3
+                col3 = col4
+                col4 = matrix_x
 
         if SILU_ACTIVATION:
             acc = acc / (1 + tl.exp(-acc))
@@ -1079,6 +1136,7 @@ def causal_conv1d_update(
     pad_slot_id: int = PAD_SLOT_ID,
     block_idx_last_scheduled_token: torch.Tensor | None = None,
     initial_state_idx: torch.Tensor | None = None,
+    retrieve_parent_token: torch.Tensor | None = None,
     validate_data=False,
 ):
     """
@@ -1160,8 +1218,8 @@ def causal_conv1d_update(
         assert num_cache_lines >= batch
         assert weight.stride(1) == 1  # Need this
 
-    # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
-    out = x
+    # 'x' can no longer be overwritten in place: with retrieve_parent_token, a parent token's input is re-read after its output has been computed
+    out = torch.empty_like(x)
     stride_w_dim, stride_w_width = weight.stride()
 
     if query_start_loc is None:
@@ -1185,6 +1243,15 @@ def causal_conv1d_update(
     state_len = width - 1
     np2_statelen = triton.next_power_of_2(state_len)
 
+    # prepare retrieve_parent_token buffer strides if provided
+    if retrieve_parent_token is not None:
+        stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token = (
+            retrieve_parent_token.stride(0),
+            retrieve_parent_token.stride(1),
+        )
+    else:
+        stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token = 0, 0
+
     def grid(META):
         return (
             batch,
@@ -1202,6 +1269,7 @@ def grid(META):
         query_start_loc,
         block_idx_last_scheduled_token,
         initial_state_idx,
+        retrieve_parent_token,
        out,
         # Matrix dimensions
         batch,
@@ -1219,6 +1287,8 @@ def grid(META):
         stride_istate_dim,
         stride_istate_token,
         stride_state_indices,
+        stride_retrieve_parent_token_seq,
+        stride_retrieve_parent_token_token,
         stride_o_seq,
         stride_o_dim,
         stride_o_token,
@@ -1234,6 +1304,7 @@ def grid(META):
         NP2_STATELEN=np2_statelen,
         USE_PAD_SLOT=pad_slot_id is not None,
         BLOCK_N=256,
+        IS_EAGLE_TREE=retrieve_parent_token is not None,
     )
     if unsqueeze:
         out = out.squeeze(-1)
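
Note (reviewer sketch, not part of the patch): a plain-Python model of the IS_EAGLE_TREE inner loop for KERNEL_WIDTH == 4, useful for checking that the col2/col1/col0 fallback lines up with the weight order. Here x_row is one channel of the current chunk, state_cols = [col0, col1, col2] are the cached inputs (oldest first), weights[3] multiplies the token itself, and tree_conv_tap is an invented name:

def tree_conv_tap(x_row, state_cols, weights, parents, token_idx):
    acc = 0.0
    idx = token_idx
    for j in range(4):  # one tap per kernel column, newest ancestor first
        if idx >= 0:
            tap = x_row[idx]
            # step to the parent, or fall out of the chunk at the tree root
            idx = parents[idx] if idx > 0 else -1
        else:
            tap = state_cols[3 + idx]  # -1 -> col2, -2 -> col1, -3 -> col0
            idx -= 1
        acc += tap * weights[3 - j]
    return acc

# Tree1 from the tests: token 3's window is [col2, x0, x1, x3]
x, w = [1.0, 2.0, 3.0, 4.0], [0.1, 0.2, 0.3, 0.4]
cols = [7.0, 8.0, 9.0]
assert tree_conv_tap(x, cols, w, [-1, 0, 0, 1], 3) == 4.0 * 0.4 + 2.0 * 0.3 + 1.0 * 0.2 + 9.0 * 0.1

Because each token rebuilds its window from its ancestor chain plus the cached state, the col0..col4 shift registers are only advanced on the non-tree path, and the output buffer must be distinct from x (hence out = torch.empty_like(x) above).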