tests/v1/spec_decode/test_causal_conv1d_with_eagle_tree.py (212 additions, 0 deletions)
@@ -0,0 +1,212 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
from vllm.platforms import current_platform


def causal_conv1d_update_ref(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None,
activation: str | None,
conv_state_indices: torch.Tensor,
num_accepted_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
retrieve_parent_token: torch.Tensor | None,
):
"""
x: (dim, seqlen)
conv_state: (chunk_size, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
out: (dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
unsqueeze = x.dim() == 2
dim, width = weight.shape
    state_len = width - 1  # history the causal conv needs per output token
    output = torch.zeros_like(x)
    conv_state_ori = conv_state.clone()  # pre-update snapshot, reused per token below
num_reqs = query_start_loc.shape[0] - 1
for j in range(num_reqs):
# update conv_state
con_seq_len = query_start_loc[j + 1] - query_start_loc[j]
x_new = torch.cat(
[
conv_state[
conv_state_indices[j], :, : num_accepted_tokens[j] + state_len - 1
],
x[:, query_start_loc[j] : query_start_loc[j + 1]],
],
dim=-1,
).to(weight.dtype)
update_state_len = state_len + con_seq_len - 1
conv_state[conv_state_indices[j], :, :update_state_len].copy_(
x_new[:, num_accepted_tokens[j] :]
)
# update output
for i in range(con_seq_len):
con_x = x[:, query_start_loc[j] + i : query_start_loc[j] + i + 1]
con_index = i
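            # Tree draft: follow parent pointers up to the root, prepending each
            # ancestor token so the conv window sees this node's root-to-node path.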
if retrieve_parent_token is not None:
while retrieve_parent_token[j, con_index] != -1:
con_index = retrieve_parent_token[j, con_index]
con_x = torch.cat(
[
x[
:,
query_start_loc[j] + con_index : query_start_loc[j]
+ con_index
+ 1,
],
con_x,
],
dim=-1,
)
else:
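                # Chain draft: the prefix is simply the preceding tokens in order.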
con_x = x[:, query_start_loc[j] : query_start_loc[j] + i + 1]
con_x = torch.cat(
[
conv_state_ori[
conv_state_indices[j],
:,
: num_accepted_tokens[j] + state_len - 1,
],
con_x,
],
dim=-1,
).to(weight.dtype)

con_x = con_x[:, -width:]

if unsqueeze:
con_x = con_x.unsqueeze(0)
out = F.conv1d(con_x, weight.unsqueeze(1), bias, padding=0, groups=dim)[
:, :, -1:
]
if unsqueeze:
out = out.squeeze(0)
output[:, query_start_loc[j] + i : query_start_loc[j] + i + 1] = out

return (output if activation is None else F.silu(output)).to(dtype=dtype_in)


@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("is_eagle_tree", [False, True])
def test_causal_conv1d_update(
has_bias: bool, silu_activation: bool, itype: torch.dtype, is_eagle_tree: bool
):
device = "cuda"
# set seed
current_platform.seed_everything(0)
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (1e-2, 5e-2)
# num_reqs = 4, max_spec_len = 5
batch_size = 4
num_speculative_tokens = 5
chunk_size = 64
dim = 2048
width = 4
    # shape is [batch_size, max_spec_len + 1]; entry i is the in-request index
    # of draft token i's parent, with -1 for the root (or padding past the last
    # tree node). In Tree1 below, node 3's ancestor chain is 3 -> 1 -> 0.
retrieve_parent_token = torch.tensor(
[
# Tree1:
# 0
# / \
# 1 2
# /
# 3
[-1, 0, 0, 1, -1, -1],
# Tree2:
# 0
# /
# 1
# /
# 2
[-1, 0, 1, -1, -1, -1],
# Tree3:
# 0
# / \
# 1 2
# / \
# 3 4
[-1, 0, 0, 1, 1, -1],
# Tree4:
# 0
# / \
# 1 2
# / \ /
# 3 4 5
[-1, 0, 0, 1, 1, 2],
],
device="cuda",
dtype=torch.int32,
)
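    # Cumulative draft token offsets; the four trees above have 4, 3, 5, and 6
    # nodes, respectively.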
spec_query_start_loc = torch.tensor(
[0, 4, 7, 12, 18],
device="cuda",
dtype=torch.int32,
)
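    # Per-request state cache slots; the conv update below only uses the first
    # slot of each request.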
spec_state_indices_tensor = torch.tensor(
[
[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24],
],
device=device,
dtype=torch.int32,
)
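    # Draft tokens accepted in the previous step; the reference keeps
    # num_accepted_tokens[j] + width - 2 cached entries as conv history.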
num_accepted_tokens = torch.tensor(
[1, 2, 1, 2],
device="cuda",
dtype=torch.int32,
)
seqlen = spec_query_start_loc[-1].item()
x = torch.rand(seqlen, dim, device=device, dtype=itype)
x_ref = x.clone().transpose(0, 1)

conv_state = torch.randn(
chunk_size, dim, width + num_speculative_tokens - 1, device=device, dtype=itype
)
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"

out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=spec_state_indices_tensor[:, 0][:batch_size],
num_accepted_tokens=num_accepted_tokens,
query_start_loc=spec_query_start_loc,
max_query_len=spec_state_indices_tensor.size(-1),
retrieve_parent_token=retrieve_parent_token if is_eagle_tree else None,
validate_data=False,
)
out_ref = causal_conv1d_update_ref(
x_ref,
conv_state_ref,
weight,
bias,
activation=activation,
conv_state_indices=spec_state_indices_tensor[:, 0][:batch_size],
num_accepted_tokens=num_accepted_tokens,
query_start_loc=spec_query_start_loc,
retrieve_parent_token=retrieve_parent_token if is_eagle_tree else None,
).transpose(0, 1)
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
tests/v1/spec_decode/test_recurrent_gated_delta_with_eagle_tree.py (164 additions, 0 deletions)
@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from einops import repeat

from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule


def recurrent_gated_delta_rule_ref(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
g: torch.Tensor,
cu_seq_lens: torch.Tensor,
ssm_state_indices: torch.Tensor,
num_accepted_tokens: torch.Tensor,
    retrieve_parent_token: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
scale: float | None = None,
):
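    """
    Sequential reference for the gated delta rule with optional tree drafts.

    q, k: (B, T, HV, K); v: (B, T, HV, V); beta, g: (B, T, HV)
    cu_seq_lens: (num_reqs + 1,), cumulative token offsets
    ssm_state_indices: (num_reqs, max_spec_len + 1), state slot per draft token
    num_accepted_tokens: (num_reqs,), tokens accepted in the previous step
    retrieve_parent_token: parent index per tree node (-1 for root / padding),
        or None for chain drafts
    initial_state: (num_slots, HV, K, V), recurrent states, updated in place
    """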
    o = torch.zeros_like(v)
q, k, v, beta, g = map(
lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g]
)

if scale is None:
scale = 1 / (q.shape[-1] ** 0.5)
q = q * scale

num_reqs = cu_seq_lens.shape[0] - 1
for j in range(num_reqs):
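        # Resume from the state saved for this request's last accepted token.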
h = initial_state[ssm_state_indices[j][num_accepted_tokens[j] - 1]]
T = cu_seq_lens[j + 1] - cu_seq_lens[j]
for i in range(T):
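            # Tree draft: each non-root node continues from its parent's saved
            # state rather than from the previous token's state.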
if retrieve_parent_token is not None and i != 0:
h = initial_state[ssm_state_indices[j][retrieve_parent_token[j][i]]]
b_q = q[:, :, cu_seq_lens[j] + i]
b_k = k[:, :, cu_seq_lens[j] + i]
b_v = v[:, :, cu_seq_lens[j] + i].clone()
h = h.clone() * g[:, :, cu_seq_lens[j] + i].exp()[..., None, None]
b_beta = beta[:, :, cu_seq_lens[j] + i]
b_v = b_v - (h.clone() * b_k[..., None]).sum(-2)
b_v = b_v * b_beta[..., None]
h = h.clone() + b_k.unsqueeze(-1) * b_v.unsqueeze(-2)
o[:, cu_seq_lens[j] + i, :] = torch.einsum("bhd,bhdm->bhm", b_q, h)
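            # Save this node's state so its children (and the next step) can
            # resume from it.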
initial_state[ssm_state_indices[j][i]] = h
return o, initial_state


@pytest.mark.parametrize("has_eagle_tree_state", [False, True])
def test_fused_recurrent(has_eagle_tree_state: bool):
torch.manual_seed(42)
H = 4
K = 128
V = 128
HV = 8

    # shape is [batch_size, max_spec_len + 1]; same parent-pointer encoding as
    # in the causal-conv1d test: entry i is the parent index of draft token i,
    # with -1 for the root and padding slots.
retrieve_parent_token = torch.tensor(
[
# Tree1:
# 0
# / \
# 1 2
# /
# 3
[-1, 0, 0, 1, -1, -1],
# Tree2:
# 0
# /
# 1
# /
# 2
[-1, 0, 1, -1, -1, -1],
# Tree3:
# 0
# / \
# 1 2
# / \
# 3 4
[-1, 0, 0, 1, 1, -1],
# Tree4:
# 0
# / \
# 1 2
# / \ /
# 3 4 5
[-1, 0, 0, 1, 1, 2],
],
device="cuda",
dtype=torch.int32,
)
# num_reqs = 4, max_spec_len = 5
cu_seq_lens = torch.tensor(
[0, 4, 7, 12, 18],
device="cuda",
dtype=torch.int32,
)
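    # State cache slot for each draft token of each request.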
spec_state_indices_tensor = torch.tensor(
[
[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24],
],
device="cuda",
dtype=torch.int32,
)
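    # Accepted tokens from the previous step; request j resumes from the state
    # slot of its last accepted token (index num_accepted_tokens[j] - 1).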
num_accepted_tokens = torch.tensor(
[2, 1, 1, 2],
device="cuda",
dtype=torch.int32,
)
    # For variable-length inputs, the batch size `B` is expected to be 1 and
    # `cu_seqlens` is required.
    B = 1
    T = cu_seq_lens[-1].item()  # total number of draft tokens across requests
    chunk_size = 64  # number of cached state slots; must exceed the max state index
q = torch.randn(B, T, H, K, dtype=torch.float16)
k = torch.randn(B, T, H, K, dtype=torch.float16)
v = torch.randn(B, T, HV, V, dtype=torch.float16)
beta = torch.randn(B, T, HV, dtype=torch.float16).sigmoid()
g = F.logsigmoid(torch.rand(B, T, HV, dtype=torch.float32))
h0 = torch.randn(chunk_size, HV, K, V, dtype=torch.float32)
q, k, v, beta, g, h0 = map(
lambda x: x.to("cuda").requires_grad_(), (q, k, v, beta, g, h0)
)

ref, ref_ht = recurrent_gated_delta_rule_ref(
q=F.normalize(
repeat(q.clone(), "b t h d -> b t (h g) d", g=HV // H), p=2, dim=-1
),
k=F.normalize(
repeat(k.clone(), "b t h d -> b t (h g) d", g=HV // H), p=2, dim=-1
),
v=v.clone(),
beta=beta.clone(),
g=g.clone(),
cu_seq_lens=cu_seq_lens,
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
retrieve_parent_token=retrieve_parent_token if has_eagle_tree_state else None,
initial_state=h0.clone(),
)
tri, tri_ht = fused_recurrent_gated_delta_rule(
q=q.clone(),
k=k.clone(),
v=v.clone(),
beta=beta.clone(),
g=g.clone(),
initial_state=h0.clone(),
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seq_lens,
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
retrieve_parent_token=retrieve_parent_token if has_eagle_tree_state else None,
inplace_final_state=True,
)

assert torch.allclose(ref, tri, atol=1e-3, rtol=1e-4)
assert torch.allclose(ref_ht, tri_ht, atol=1e-3, rtol=1e-4)