
Commit b47f50b

support tree attention in gdn
Signed-off-by: liumengge1205 <[email protected]>
1 parent 0037b57 commit b47f50b

File tree

4 files changed (+577 −79 lines changed)
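
Both test files below exercise the same tree-shaped speculative-decoding layout: each request's row of `retrieve_parent_token` stores, at position `i`, the index of draft token `i`'s parent within that request, with `-1` marking the root (or padding). As a minimal standalone sketch (the helper `branch_to_root` is illustrative only, not part of this commit), walking the parent pointers recovers the branch a given draft token conditions on:

    # Illustrative sketch, not part of this commit: recover the root-to-node
    # path encoded by one row of retrieve_parent_token. Entry i holds the
    # parent position of draft token i; -1 marks the root (or padding).
    def branch_to_root(parents: list[int], node: int) -> list[int]:
        path = [node]
        while parents[node] != -1:
            node = parents[node]
            path.append(node)
        return path[::-1]  # root first

    # Tree1 from the tests ([-1, 0, 0, 1, -1, -1]):
    #     0
    #    / \
    #   1   2
    #  /
    # 3
    assert branch_to_root([-1, 0, 0, 1, -1, -1], 3) == [0, 1, 3]
    assert branch_to_root([-1, 0, 0, 1, -1, -1], 2) == [0, 2]

This is exactly the walk the conv1d reference below performs with `while retrieve_parent_token[j, con_index] != -1`.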
Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
from vllm.platforms import current_platform


def causal_conv1d_update_ref(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    activation: str | None,
    conv_state_indices: torch.Tensor,
    num_accepted_tokens: torch.Tensor,
    query_start_loc: torch.Tensor,
    retrieve_parent_token: torch.Tensor | None,
):
    """
    x: (dim, seqlen)
    conv_state: (chunk_size, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    out: (dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    unsqueeze = x.dim() == 2
    dim, width = weight.shape
    state_len = width - 1
    output = torch.zeros_like(x)
    conv_state_ori = conv_state.clone()
    num_reqs = query_start_loc.shape[0] - 1
    for j in range(num_reqs):
        # update conv_state
        con_seq_len = query_start_loc[j + 1] - query_start_loc[j]
        x_new = torch.cat(
            [
                conv_state[
                    conv_state_indices[j], :, : num_accepted_tokens[j] + state_len - 1
                ],
                x[:, query_start_loc[j] : query_start_loc[j + 1]],
            ],
            dim=-1,
        ).to(weight.dtype)
        update_state_len = state_len + con_seq_len - 1
        conv_state[conv_state_indices[j], :, :update_state_len].copy_(
            x_new[:, num_accepted_tokens[j] :]
        )
        # update output
        for i in range(con_seq_len):
            con_x = x[:, query_start_loc[j] + i : query_start_loc[j] + i + 1]
            con_index = i
            if retrieve_parent_token is not None:
                while retrieve_parent_token[j, con_index] != -1:
                    con_index = retrieve_parent_token[j, con_index]
                    con_x = torch.cat(
                        [
                            x[
                                :,
                                query_start_loc[j] + con_index : query_start_loc[j]
                                + con_index
                                + 1,
                            ],
                            con_x,
                        ],
                        dim=-1,
                    )
            else:
                con_x = x[:, query_start_loc[j] : query_start_loc[j] + i + 1]
            con_x = torch.cat(
                [
                    conv_state_ori[
                        conv_state_indices[j],
                        :,
                        : num_accepted_tokens[j] + state_len - 1,
                    ],
                    con_x,
                ],
                dim=-1,
            ).to(weight.dtype)

            con_x = con_x[:, -width:]

            if unsqueeze:
                con_x = con_x.unsqueeze(0)
            out = F.conv1d(con_x, weight.unsqueeze(1), bias, padding=0, groups=dim)[
                :, :, -1:
            ]
            if unsqueeze:
                out = out.squeeze(0)
            output[:, query_start_loc[j] + i : query_start_loc[j] + i + 1] = out

    return (output if activation is None else F.silu(output)).to(dtype=dtype_in)


@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("is_eagle_tree", [False, True])
def test_causal_conv1d_update(
    has_bias: bool, silu_activation: bool, itype: torch.dtype, is_eagle_tree: bool
):
    device = "cuda"
    # set seed
    current_platform.seed_everything(0)
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (1e-2, 5e-2)
    # num_reqs = 4, max_spec_len = 5
    batch_size = 4
    num_speculative_tokens = 5
    chunk_size = 64
    dim = 2048
    width = 4
    # shape is [batch_size, max_spec_len + 1]
    retrieve_parent_token = torch.tensor(
        [
            # Tree1:
            #     0
            #    / \
            #   1   2
            #  /
            # 3
            [-1, 0, 0, 1, -1, -1],
            # Tree2:
            #     0
            #    /
            #   1
            #  /
            # 2
            [-1, 0, 1, -1, -1, -1],
            # Tree3:
            #     0
            #    / \
            #   1   2
            #  / \
            # 3   4
            [-1, 0, 0, 1, 1, -1],
            # Tree4:
            #     0
            #    / \
            #   1   2
            #  / \  /
            # 3  4 5
            [-1, 0, 0, 1, 1, 2],
        ],
        device="cuda",
        dtype=torch.int32,
    )
    spec_query_start_loc = torch.tensor(
        [0, 4, 7, 12, 18],
        device="cuda",
        dtype=torch.int32,
    )
    spec_state_indices_tensor = torch.tensor(
        [
            [1, 2, 3, 4, 5, 6],
            [7, 8, 9, 10, 11, 12],
            [13, 14, 15, 16, 17, 18],
            [19, 20, 21, 22, 23, 24],
        ],
        device=device,
        dtype=torch.int32,
    )
    num_accepted_tokens = torch.tensor(
        [1, 2, 1, 2],
        device="cuda",
        dtype=torch.int32,
    )
    seqlen = spec_query_start_loc[-1].item()
    x = torch.rand(seqlen, dim, device=device, dtype=itype)
    x_ref = x.clone().transpose(0, 1)

    conv_state = torch.randn(
        chunk_size, dim, width + num_speculative_tokens - 1, device=device, dtype=itype
    )
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"

    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=spec_state_indices_tensor[:, 0][:batch_size],
        num_accepted_tokens=num_accepted_tokens,
        query_start_loc=spec_query_start_loc,
        max_query_len=spec_state_indices_tensor.size(-1),
        retrieve_parent_token=retrieve_parent_token if is_eagle_tree else None,
        validate_data=False,
    )
    out_ref = causal_conv1d_update_ref(
        x_ref,
        conv_state_ref,
        weight,
        bias,
        activation=activation,
        conv_state_indices=spec_state_indices_tensor[:, 0][:batch_size],
        num_accepted_tokens=num_accepted_tokens,
        query_start_loc=spec_query_start_loc,
        retrieve_parent_token=retrieve_parent_token if is_eagle_tree else None,
    ).transpose(0, 1)
    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
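
`causal_conv1d_update_ref` gathers each draft token's branch back to the root (prepending the cached `conv_state` prefix) and convolves only the trailing `width` positions. Stripped of the tree and cache bookkeeping, the per-token operation is a depthwise causal convolution; a minimal sketch, assuming a single request with no cached state (illustrative only, not part of this commit):

    import torch
    import torch.nn.functional as F

    def causal_conv1d_step(window: torch.Tensor, weight: torch.Tensor,
                           bias: torch.Tensor | None) -> torch.Tensor:
        # window: (dim, width) trailing inputs; weight: (dim, width); bias: (dim,)
        dim, width = weight.shape
        out = F.conv1d(window.unsqueeze(0), weight.unsqueeze(1), bias,
                       padding=0, groups=dim)  # depthwise: one filter per channel
        return out.squeeze(0).squeeze(-1)  # (dim,)

    dim, width = 8, 4
    w = torch.randn(dim, width)
    x = torch.randn(dim, 10)
    y = causal_conv1d_step(x[:, -width:], w, None)
    # each channel is a dot product of its filter with the trailing window
    assert torch.allclose(y, (x[:, -width:] * w).sum(-1), atol=1e-5)

With `groups=dim`, each channel sees only its own window, which is why the tree case can be handled purely by choosing which `width` inputs to place in the window.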
Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from einops import repeat

from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule


def recurrent_gated_delta_rule_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g: torch.Tensor,
    cu_seq_lens: torch.Tensor,
    ssm_state_indices: torch.Tensor,
    num_accepted_tokens: torch.Tensor,
    retrieve_parent_token: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    scale: float | None = None,
):
    o = torch.zeros(*v.shape).to(v)
    q, k, v, beta, g = map(
        lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g]
    )

    if scale is None:
        scale = 1 / (q.shape[-1] ** 0.5)
    q = q * scale

    num_reqs = cu_seq_lens.shape[0] - 1
    for j in range(num_reqs):
        h = initial_state[ssm_state_indices[j][num_accepted_tokens[j] - 1]]
        T = cu_seq_lens[j + 1] - cu_seq_lens[j]
        for i in range(T):
            if retrieve_parent_token is not None and i != 0:
                h = initial_state[ssm_state_indices[j][retrieve_parent_token[j][i]]]
            b_q = q[:, :, cu_seq_lens[j] + i]
            b_k = k[:, :, cu_seq_lens[j] + i]
            b_v = v[:, :, cu_seq_lens[j] + i].clone()
            h = h.clone() * g[:, :, cu_seq_lens[j] + i].exp()[..., None, None]
            b_beta = beta[:, :, cu_seq_lens[j] + i]
            b_v = b_v - (h.clone() * b_k[..., None]).sum(-2)
            b_v = b_v * b_beta[..., None]
            h = h.clone() + b_k.unsqueeze(-1) * b_v.unsqueeze(-2)
            o[:, cu_seq_lens[j] + i, :] = torch.einsum("bhd,bhdm->bhm", b_q, h)
            initial_state[ssm_state_indices[j][i]] = h
    return o, initial_state


@pytest.mark.parametrize("has_eagle_tree_state", [False, True])
def test_fused_recurrent(has_eagle_tree_state: bool):
    torch.manual_seed(42)
    H = 4
    K = 128
    V = 128
    HV = 8

    # shape is [batch_size, max_spec_len + 1]
    retrieve_parent_token = torch.tensor(
        [
            # Tree1:
            #     0
            #    / \
            #   1   2
            #  /
            # 3
            [-1, 0, 0, 1, -1, -1],
            # Tree2:
            #     0
            #    /
            #   1
            #  /
            # 2
            [-1, 0, 1, -1, -1, -1],
            # Tree3:
            #     0
            #    / \
            #   1   2
            #  / \
            # 3   4
            [-1, 0, 0, 1, 1, -1],
            # Tree4:
            #     0
            #    / \
            #   1   2
            #  / \  /
            # 3  4 5
            [-1, 0, 0, 1, 1, 2],
        ],
        device="cuda",
        dtype=torch.int32,
    )
    # num_reqs = 4, max_spec_len = 5
    cu_seq_lens = torch.tensor(
        [0, 4, 7, 12, 18],
        device="cuda",
        dtype=torch.int32,
    )
    spec_state_indices_tensor = torch.tensor(
        [
            [1, 2, 3, 4, 5, 6],
            [7, 8, 9, 10, 11, 12],
            [13, 14, 15, 16, 17, 18],
            [19, 20, 21, 22, 23, 24],
        ],
        device="cuda",
        dtype=torch.int32,
    )
    num_accepted_tokens = torch.tensor(
        [2, 1, 1, 2],
        device="cuda",
        dtype=torch.int32,
    )
    # for variable-length inputs, the batch size `B` is expected to be 1
    # and `cu_seqlens` is required
    B = 1
    T = cu_seq_lens.max()
    chunk_size = 64
    q = torch.randn(B, T, H, K, dtype=torch.float16)
    k = torch.randn(B, T, H, K, dtype=torch.float16)
    v = torch.randn(B, T, HV, V, dtype=torch.float16)
    beta = torch.randn(B, T, HV, dtype=torch.float16).sigmoid()
    g = F.logsigmoid(torch.rand(B, T, HV, dtype=torch.float32))
    h0 = torch.randn(chunk_size, HV, K, V, dtype=torch.float32)
    q, k, v, beta, g, h0 = map(
        lambda x: x.to("cuda").requires_grad_(), (q, k, v, beta, g, h0)
    )

    ref, ref_ht = recurrent_gated_delta_rule_ref(
        q=F.normalize(
            repeat(q.clone(), "b t h d -> b t (h g) d", g=HV // H), p=2, dim=-1
        ),
        k=F.normalize(
            repeat(k.clone(), "b t h d -> b t (h g) d", g=HV // H), p=2, dim=-1
        ),
        v=v.clone(),
        beta=beta.clone(),
        g=g.clone(),
        cu_seq_lens=cu_seq_lens,
        ssm_state_indices=spec_state_indices_tensor,
        num_accepted_tokens=num_accepted_tokens,
        retrieve_parent_token=retrieve_parent_token if has_eagle_tree_state else None,
        initial_state=h0.clone(),
    )
    tri, tri_ht = fused_recurrent_gated_delta_rule(
        q=q.clone(),
        k=k.clone(),
        v=v.clone(),
        beta=beta.clone(),
        g=g.clone(),
        initial_state=h0.clone(),
        use_qk_l2norm_in_kernel=True,
        cu_seqlens=cu_seq_lens,
        ssm_state_indices=spec_state_indices_tensor,
        num_accepted_tokens=num_accepted_tokens,
        retrieve_parent_token=retrieve_parent_token if has_eagle_tree_state else None,
        inplace_final_state=True,
    )

    assert torch.allclose(ref, tri, atol=1e-3, rtol=1e-4)
    assert torch.allclose(ref_ht, tri_ht, atol=1e-3, rtol=1e-4)
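
For orientation, the per-token state update that `recurrent_gated_delta_rule_ref` performs (and the fused kernel is checked against) reduces, for a single head, to the gated delta rule: decay the (K, V) state by exp(g), form a beta-gated prediction error for the value, apply a rank-1 correction, then read out with the query. A minimal single-head sketch (illustrative only; the test additionally L2-normalizes and scales q and k upstream, and in the tree case reloads h from the parent token's saved state before each step):

    import torch

    def gated_delta_step(h, q, k, v, g, beta):
        # h: (K, V) recurrent state; g: scalar log-space decay; beta: scalar gate
        h = h * g.exp()                # decay the state
        v_err = beta * (v - k @ h)     # delta rule: gated prediction error
        h = h + torch.outer(k, v_err)  # rank-1 state update
        o = q @ h                      # read out, shape (V,)
        return o, h

    K, V = 128, 128
    h = torch.zeros(K, V)
    q, k, v = torch.randn(K), torch.randn(K), torch.randn(V)
    o, h = gated_delta_step(h, q, k, v, torch.tensor(-0.1), torch.tensor(0.5))

Writing each step's resulting state back to its own slot (`initial_state[ssm_state_indices[j][i]] = h`) is what lets sibling branches of the tree later restart from their shared parent's state.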
