Commit 0ff951c

Add attention sinks for FlashAttention-2
Signed-off-by: root <[email protected]>
1 parent a893712 commit 0ff951c

File tree

9 files changed: +455, -17 lines

benchmarks/benchmark_fa2_sinks.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
"""
Single-file FlashAttention2 with attention sinks benchmark.

Usage:
    python3 benchmark_fa2_sinks.py

Parameters:
    so_path  - path to the shared library (.so)
    path_new - path to save results
"""
import csv
import math
import os

import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from einops import rearrange

import vllm

pkg = os.path.dirname(vllm.__file__)
so_path = os.path.join(pkg, "vllm_flash_attn", "_vllm_fa2_C.abi3.so")
path_new = os.path.join(".", "benchmark_fa2_sinks.csv")

csv_rows = []

def benchmark_forward(
    fn, inputs, repeats=10, desc="", verbose=False, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use PyTorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m

def benchmark_fwd_bwd(
    fn,
    inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on an arbitrary function.

    Currently delegates to benchmark_forward, so only the forward pass is timed.
    """
    return benchmark_forward(
        fn,
        inputs,
        repeats=repeats,
        desc=desc,
        verbose=verbose,
        amp=amp,
        amp_dtype=amp_dtype,
        **kwinputs,
    )

# Unused here; kept as placeholders.
attention_triton = None
xops = None

def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Reference eager-mode attention, kept for comparison.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)


def time_fwd_bwd(func, *args, **kwargs):
    time_f = benchmark_fwd_bwd(func, *args, **kwargs)
    return time_f[1].mean


repeats = 30
device = 'cuda'
dtype = torch.float16

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}

def test_time(path, func_name, old_or_new):
    for causal in causal_vals:
        for headdim in headdim_vals:
            for batch_size, seqlen in bs_seqlen_vals:
                config = (causal, headdim, batch_size, seqlen)
                nheads = dim // headdim

                qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                                  requires_grad=True)
                q = qkv[:, :, 0]  # (B, S, H, D)
                k = qkv[:, :, 1]  # (B, S, H, D)
                v = qkv[:, :, 2]  # (B, S, H, D)

                # Flatten to the varlen layout (total_tokens, H, D).
                q = q.reshape(-1, nheads, headdim)
                k = k.reshape(-1, nheads, headdim)
                v = v.reshape(-1, nheads, headdim)
                # One sink logit per head.
                s_aux = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True)
                out_buf = torch.empty_like(q)
                # Positional arguments for varlen_fwd, in schema order.
                fa2_fwd_closure = [q, k, v,
                                   out_buf,
                                   torch.tensor([seqlen * i for i in range(batch_size + 1)], device=device, dtype=torch.int32),
                                   torch.tensor([seqlen * i for i in range(batch_size + 1)], device=device, dtype=torch.int32),
                                   None,
                                   None,
                                   None,
                                   None,
                                   seqlen,
                                   seqlen,
                                   dropout_p,
                                   torch.tensor(1.0 / (headdim ** 0.5), device=device),
                                   False,
                                   causal,
                                   -1,
                                   -1,
                                   0.0,
                                   dropout_p > 0,
                                   None,
                                   s_aux
                                   ]
                f = time_fwd_bwd(func_name, fa2_fwd_closure, repeats=repeats, verbose=False)
                time_f[config, "Flash2"] = f
                print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
                speed_f[config, "Flash2"] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, "Flash2"])
                print(f"Flash2 fwd: {speed_f[config, 'Flash2']:.2f} TFLOPs/s")
                csv_rows.append([causal, headdim, batch_size, seqlen, f"{speed_f[config, 'Flash2']:.2f}"])
    with open(path, "a", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(["causal", "headdim", "batch_size", "seqlen", "TFLOPs/s"])
        writer.writerows(csv_rows)

    print(f"Results written to {path}")

torch.ops.load_library(so_path)
func_name = torch.ops._vllm_fa2_C.varlen_fwd
test_time(path_new, func_name, "new")
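
Editor's note: for reference, here is a minimal eager-mode sketch of the semantics this benchmark exercises, under our reading of the change (not code from this commit): each entry of s_aux acts as one extra per-head softmax logit that joins the normalization but contributes no value vector. Shapes follow the (B, S, H, D) layout above; the function name attention_with_sinks_ref is ours.

import math
import torch

def attention_with_sinks_ref(q, k, v, s_aux, causal=True):
    """Naive attention where s_aux[h] is an extra per-head sink logit.

    q, k, v: (batch, seqlen, nheads, head_dim); s_aux: (nheads,).
    """
    b, s, h, d = q.shape
    scores = torch.einsum('bthd,bshd->bhts', q, k).float() / math.sqrt(d)
    if causal:
        mask = torch.triu(torch.full((s, s), float('-inf'), device=q.device), 1)
        scores = scores + mask
    # The sink appears as one extra key column that every query can attend to...
    sink = s_aux.float().view(1, h, 1, 1).expand(b, h, s, 1)
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    # ...but it carries no value, so its probability mass is simply dropped.
    out = torch.einsum('bhts,bshd->bthd', probs[..., :-1].to(v.dtype), v)
    return out

With s_aux filled with a very large negative value the sink gets zero probability mass and this reduces to ordinary softmax attention, which is a convenient sanity check.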

csrc/flash_attn/flash_api.cpp

Lines changed: 43 additions & 6 deletions
@@ -49,7 +49,11 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_right,
                       const float softcap,
                       bool seqlenq_ngroups_swapped=false,
-                      const bool unpadded_lse=false) {
+                      const bool unpadded_lse=false,
+                      //TODO(dudugong-gitch): sinks
+                      const std::optional<at::Tensor> s_aux_ = std::nullopt,
+                      //TODO(dudugong-gitch):q heads per k heads
+                      int q_heads_per_k_heads = 1) {
 
     // Reset the parameters
     params = {};
@@ -81,6 +85,23 @@ void set_params_fprop(Flash_fwd_params &params,
             params.o_batch_stride *= seqlen_q;
         }
     }
+    // TODO(dudugong-gitch): sink-tokens support – integrate into params
+    if (s_aux_.has_value()) {
+        TORCH_CHECK(q.size(-1) == v.size(-1),
+                    "We don't support S_aux with hdim != hdim_v");
+        auto s_aux = s_aux_.value();
+        TORCH_CHECK(s_aux.device() == q.device(),
+                    "s_aux must be on the same device as q");
+        TORCH_CHECK(s_aux.ndimension() == 1,
+                    c10::str("s_aux must be 1-D, but got ", s_aux.ndimension(), "-D."));
+        params.s_aux_ptr = s_aux.data_ptr();
+        TORCH_CHECK(s_aux.scalar_type() == q.scalar_type(),
+                    "s_aux and q must have the same dtype");
+    } else {
+        params.s_aux_ptr = nullptr;
+    }
+    //TODO(dudugong-gitch): q_heads_per_k_heads
+    params.q_heads_per_k_heads = q_heads_per_k_heads;
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
@@ -533,7 +554,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                int window_size_right,
                const float softcap,
                const bool return_softmax,
-               std::optional<at::Generator> gen_) {
+               std::optional<at::Generator> gen_,
+               //TODO(dudugong-gitch): sinks
+               const std::optional<at::Tensor> &s_aux = std::nullopt) {
 
     // Otherwise the kernel will be launched from cuda:0 device
     at::cuda::CUDAGuard device_guard{q.device()};
@@ -588,6 +611,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
 
     void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
 
+    //TODO(dudugong-gitch): validate shape
+    TORCH_CHECK(!s_aux.has_value() || s_aux.value().size(0) == q.size(-2),
+                c10::str("s_aux.size(0) must equal the number of heads of q (",
+                         q.size(-2), "), but got ", s_aux.value().size(0), "."));
+
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
@@ -689,7 +717,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_right,
                      softcap,
                      seqlenq_ngroups_swapped,
-                     /*unpadded_lse*/true);
+                     /*unpadded_lse*/true,
+                     /*s aux*/s_aux,
+                     /*q_heads_per_k_heads*/ngroups);
     params.total_q = total_q;
 
     if (paged_KV) {
@@ -1251,7 +1281,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 int window_size_right,
                 const float softcap,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-                int num_splits
+                int num_splits,
+                //TODO(dudugong-gitch): sinks
+                const std::optional<at::Tensor> &s_aux = std::nullopt
                 ) {
 
     // Otherwise the kernel will be launched from cuda:0 device
@@ -1309,6 +1341,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
+    //TODO(dudugong-gitch): q_heads_per_k_heads
+    const int q_heads_per_k_heads = num_heads / num_heads_k;
     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
@@ -1384,8 +1418,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     softcap
-                     );
+                     softcap,
+                     /*seqlenq_ngroups_swapped default*/false,
+                     /*unpadded_lse default*/false,
+                     /*s aux*/s_aux,
+                     /*q_heads_per_k_heads*/q_heads_per_k_heads);
 
     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
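
Editor's note: the TORCH_CHECKs above pin down the host-side contract for s_aux. A small Python mirror of those constraints, assuming the varlen layout where q is (total_q, num_heads, head_dim); the helper name is ours, for illustration only.

import torch

def validate_s_aux(q: torch.Tensor, v: torch.Tensor, s_aux: torch.Tensor) -> None:
    """Mirror of the s_aux checks in set_params_fprop / mha_varlen_fwd."""
    assert q.size(-1) == v.size(-1), "s_aux is not supported when hdim != hdim_v"
    assert s_aux.ndimension() == 1, f"s_aux must be 1-D, but got {s_aux.ndimension()}-D"
    assert s_aux.size(0) == q.size(-2), "s_aux needs one entry per query head"
    assert s_aux.device == q.device, "s_aux must be on the same device as q"
    assert s_aux.dtype == q.dtype, "s_aux and q must have the same dtype"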

csrc/flash_attn/flash_api_sparse.cpp

Lines changed: 5 additions & 1 deletion
@@ -55,7 +55,11 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_right,
                       const float softcap,
                       bool seqlenq_ngroups_swapped=false,
-                      const bool unpadded_lse=false);
+                      const bool unpadded_lse=false,
+                      //TODO(dudugong-gitch): sinks
+                      std::optional<at::Tensor> s_aux_ = std::nullopt,
+                      //TODO(dudugong-gitch):q_heads_per_k_heads
+                      int q_heads_per_k_heads = 1);
 
 std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,

csrc/flash_attn/flash_api_torch_lib.cpp

Lines changed: 10 additions & 4 deletions
@@ -35,7 +35,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                int window_size_right,
                const float softcap,
                const bool return_softmax,
-               std::optional<at::Generator> gen_);
+               std::optional<at::Generator> gen_,
+               //TODO(dudugong-gitch): sinks
+               const std::optional<at::Tensor> &s_aux);
 
 std::vector<at::Tensor>
 mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
@@ -57,7 +59,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 int window_size_right,
                 const float softcap,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-                int num_splits);
+                int num_splits,
+                //TODO(dudugong-gitch): sinks
+                const std::optional<at::Tensor> &s_aux_);
 
 /////////////////////////// From flash_api_sparse.cpp //////////////////////////
 
@@ -104,18 +108,20 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
 /**
  * Torch Library Registration
  */
+// TODO(dudugong-gitch): sync pybind once TORCH_LIBRARY supports optional<Tensor> for sink tokens
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
     ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
             "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, "
             "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
             "bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, "
-            "Generator? gen) -> Tensor[]");
+            "Generator? gen, "
+            "Tensor? s_aux = None) -> Tensor[]");
     ops.impl("varlen_fwd", torch::kCUDA, make_pytorch_shim(&mha_varlen_fwd));
 
     ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? seqlens_k, "
             "Tensor? rotary_cos, Tensor? rotary_sin, Tensor? cache_batch_idx, Tensor? leftpad_k, Tensor? block_table, "
             "Tensor? alibi_slopes, Tensor!? out, float softmax_scale, bool is_causal, int window_size_left, "
-            "int window_size_right, float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]");
+            "int window_size_right, float softcap, bool is_rotary_interleaved, int num_splits, Tensor? s_aux_) -> Tensor[]");
     ops.impl("fwd_kvcache", torch::kCUDA, make_pytorch_shim(&mha_fwd_kvcache));
 
     ops.def("fwd_sparse(Tensor! q, Tensor k, Tensor v, "

csrc/flash_attn/src/flash.h

Lines changed: 4 additions & 0 deletions
@@ -41,6 +41,8 @@ struct Qkv_params {
     // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
     // different from nheads (query).
     int h_h_k_ratio; // precompute h / h_k,
+    //TODO(dudugong-gitch):q_heads_per_k_heads
+    int q_heads_per_k_heads = 1;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -140,6 +142,8 @@ struct Flash_fwd_params : public Qkv_params {
 
     bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
     bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
+    // TODO(dudugong-gitch): sink-token vector pointer
+    void *__restrict__ s_aux_ptr = nullptr;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
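
Editor's note: since a sink logit contributes to the softmax denominator but to no value vector, attention-with-sinks is mathematically the sink-free attention output rescaled per row and head by sigmoid(LSE - s_aux), where LSE is the row's log-sum-exp of the (already scaled) attention scores. The kernel presumably folds s_aux_ptr into its online-softmax accumulator; the kernel-side changes are outside this excerpt, so the identity below is only offered as a post-hoc cross-check against a sink-free run, assuming an LSE of shape (batch, nheads, seqlen). The helper name is ours.

import torch

def apply_sink_rescale(out, lse, s_aux):
    """out: (B, S, H, D) sink-free attention output; lse: (B, H, S) per-row
    log-sum-exp of the scaled scores; s_aux: (H,). Returns what the same
    attention would produce with per-head sink logits."""
    scale = torch.sigmoid(lse.float() - s_aux.float().view(1, -1, 1))  # (B, H, S)
    return out * scale.transpose(1, 2).unsqueeze(-1).to(out.dtype)

As s_aux tends to a very large negative value the scale approaches 1 and the standard output is recovered, matching the reference sketch after the benchmark above.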
