From 0ff951ce9faa60570f41c00717d216a146d9ada8 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 19 Oct 2025 07:08:47 -0700 Subject: [PATCH] add attention sinks for flash attention2 Signed-off-by: root --- benchmarks/benchmark_fa2_sinks.py | 185 ++++++++++++++++++++++++ csrc/flash_attn/flash_api.cpp | 49 ++++++- csrc/flash_attn/flash_api_sparse.cpp | 6 +- csrc/flash_attn/flash_api_torch_lib.cpp | 14 +- csrc/flash_attn/src/flash.h | 4 + csrc/flash_attn/src/flash_fwd_kernel.h | 60 +++++++- csrc/flash_attn/src/softmax.h | 19 +++ tests/test_fa2_sinks.py | 132 +++++++++++++++++ vllm_flash_attn/flash_attn_interface.py | 3 +- 9 files changed, 455 insertions(+), 17 deletions(-) create mode 100644 benchmarks/benchmark_fa2_sinks.py create mode 100644 tests/test_fa2_sinks.py diff --git a/benchmarks/benchmark_fa2_sinks.py b/benchmarks/benchmark_fa2_sinks.py new file mode 100644 index 0000000000..43434a84c1 --- /dev/null +++ b/benchmarks/benchmark_fa2_sinks.py @@ -0,0 +1,185 @@ +""" +Single-file FlashAttention2 with attention sinks benchmarks. + +Usage: + python3 benchmark_fa2_sinks.py + +Parameters: + so_path - path to the shared library (.so) + path_new - path to save results +""" +import math +import torch +from einops import rearrange +import torch.utils.benchmark as benchmark +import csv +import vllm +import os + +pkg = os.path.dirname(vllm.__file__) +so_path = os.path.join(pkg, "vllm_flash_attn", "_vllm_fa2_C.abi3.so") +path_new = os.path.join(".", "benchmark_fa2_sinks.csv") + +csv_rows = [] + +def benchmark_forward( + fn, inputs, repeats=10, desc="", verbose=False, amp=False, amp_dtype=torch.float16, **kwinputs +): + """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward pass") + + def amp_wrapper(inputs, **kwinputs): + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + fn(*inputs) + + t = benchmark.Timer( + stmt="fn_amp(inputs, **kwinputs)", + globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + +def benchmark_fwd_bwd( + fn, + inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return benchmark_forward( + fn, + inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ) + +attention_triton = None +xops = None + +def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): + assert mode in ["fwd", "bwd", "fwd_bwd"] + f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) + return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) + +def efficiency(flop, time): + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + + +def attention_pytorch(qkv, dropout_p=0.0, causal=True): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, head_dim) + dropout_p: float + Output: + output: (batch_size, seqlen, nheads, head_dim) + """ + batch_size, seqlen, _, nheads, d = qkv.shape + q, k, v = qkv.unbind(dim=2) + q = rearrange(q, 'b t h d -> (b h) t d') + k = rearrange(k, 'b s h d -> (b h) d s') + softmax_scale = 1.0 / math.sqrt(d) + scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) + scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), + '(b h) t s -> b h t s', h=nheads) + if causal: + causal_mask = 
torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
+    attention = torch.softmax(scores, dim=-1)
+    attention_drop = torch.nn.functional.dropout(attention, dropout_p)
+    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+    return output.to(dtype=qkv.dtype)
+
+
+def time_fwd_bwd(func, *args, **kwargs):
+    time_f = benchmark_fwd_bwd(func, *args, **kwargs)
+    return time_f[1].mean
+
+
+repeats = 30
+device = 'cuda'
+dtype = torch.float16
+
+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+causal_vals = [False, True]
+headdim_vals = [64, 128]
+dim = 2048
+dropout_p = 0.0
+
+time_f = {}
+time_b = {}
+time_f_b = {}
+speed_f = {}
+speed_b = {}
+speed_f_b = {}
+def test_time(path, func_name, old_or_new):
+    for causal in causal_vals:
+        for headdim in headdim_vals:
+            for batch_size, seqlen in bs_seqlen_vals:
+                config = (causal, headdim, batch_size, seqlen)
+                nheads = dim // headdim
+
+                qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                                  requires_grad=True)
+                q = qkv[:, :, 0]  # (B, S, H, D)
+                k = qkv[:, :, 1]  # (B, S, H, D)
+                v = qkv[:, :, 2]  # (B, S, H, D)
+
+                q = q.reshape(-1, nheads, headdim)
+                k = k.reshape(-1, nheads, headdim)
+                v = v.reshape(-1, nheads, headdim)
+                s_aux = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True)
+                out_buf = torch.empty_like(q)
+                fa2_fwd_closure = [q, k, v,
+                                   out_buf,
+                                   torch.tensor([seqlen * i for i in range(batch_size + 1)], device=device, dtype=torch.int32),
+                                   torch.tensor([seqlen * i for i in range(batch_size + 1)], device=device, dtype=torch.int32),
+                                   None,
+                                   None,
+                                   None,
+                                   None,
+                                   seqlen,
+                                   seqlen,
+                                   dropout_p,
+                                   torch.tensor(1.0 / (headdim ** 0.5), device=device),
+                                   False,
+                                   causal,
+                                   -1,
+                                   -1,
+                                   0.0,
+                                   dropout_p > 0,
+                                   None,
+                                   s_aux
+                                   ]
+                f = time_fwd_bwd(func_name, fa2_fwd_closure, repeats=repeats, verbose=False)
+                time_f[config, "Flash2"] = f
+                print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
+                speed_f[config, "Flash2"] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
+                    time_f[config, "Flash2"])
+                print(
+                    f"Flash2 fwd: {speed_f[config, 'Flash2']:.2f} TFLOPs/s, "
+                )
+                csv_rows.append([causal, headdim, batch_size, seqlen,
+                                 f"{speed_f[config, 'Flash2']:.2f}"])
+    with open(path, "a", newline="") as fp:
+        writer = csv.writer(fp)
+        writer.writerow(["causal", "headdim", "batch_size", "seqlen", "TFLOPs/s"])
+        writer.writerows(csv_rows)
+
+    print(f"Results written to {path}")
+
+torch.ops.load_library(so_path)
+func_name = torch.ops._vllm_fa2_C.varlen_fwd
+test_time(path_new, func_name, "new")
diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 2c890d47a0..be23ccccf8 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -49,7 +49,11 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_right,
                       const float softcap,
                       bool seqlenq_ngroups_swapped=false,
-                      const bool unpadded_lse=false) {
+                      const bool unpadded_lse=false,
+                      //TODO(dudugong-gitch): sinks
+                      const std::optional<at::Tensor> s_aux_ = std::nullopt,
+                      //TODO(dudugong-gitch): q heads per k heads
+                      int q_heads_per_k_heads = 1) {
 
     // Reset the parameters
     params = {};
@@ -81,6 +85,23 @@ void set_params_fprop(Flash_fwd_params &params,
             params.o_batch_stride *= seqlen_q;
         }
     }
+    // TODO(dudugong-gitch): sink-tokens support – integrate into params
+    if (s_aux_.has_value()) {
+        TORCH_CHECK(q.size(-1) == v.size(-1),
+                    "We don't support S_aux with hdim != 
hdim_v"); + auto s_aux = s_aux_.value(); + TORCH_CHECK(s_aux.device() == q.device(), + "s_aux must be on the same device as q"); + TORCH_CHECK(s_aux.ndimension() == 1, + c10::str("s_aux must be 1-D, but got ", s_aux.ndimension(), "-D.")); + params.s_aux_ptr = s_aux.data_ptr(); + TORCH_CHECK(s_aux.scalar_type() == q.scalar_type(), + "s_aux and q must have the same dtype"); + } else { + params.s_aux_ptr = nullptr; + } + //TODO(dudugong-gitch): q_heads_per_k_heads + params.q_heads_per_k_heads = q_heads_per_k_heads; params.cu_seqlens_q = static_cast(cu_seqlens_q_d); params.cu_seqlens_k = static_cast(cu_seqlens_k_d); @@ -533,7 +554,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { + std::optional gen_, + //TODO(dudugong-gitch): sinks + const std::optional &s_aux = std::nullopt) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; @@ -588,6 +611,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + //TODO(dudugong-gitch): validate shape + TORCH_CHECK(!s_aux.has_value() || s_aux.value().size(0) == q.size(-2), + c10::str("s_aux.size(0) must equal the number of heads of q (", + q.size(-2), "), but got ", s_aux.value().size(0), ".")); + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); @@ -689,7 +717,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s window_size_right, softcap, seqlenq_ngroups_swapped, - /*unpadded_lse*/true); + /*unpadded_lse*/true, + /*s aux*/s_aux, + /*q_heads_per_k_heads*/ngroups); params.total_q = total_q; if (paged_KV) { @@ -1251,7 +1281,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he int window_size_right, const float softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - int num_splits + int num_splits, + //TODO(dudugong-gitch): sinks + const std::optional &s_aux = std::nullopt ) { // Otherwise the kernel will be launched from cuda:0 device @@ -1309,6 +1341,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza + //TODO(dudugong-gitch): q_heads_per_k_heads + const int q_heads_per_k_heads = num_heads / num_heads_k; const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; @@ -1384,8 +1418,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he softmax_scale, window_size_left, window_size_right, - softcap - ); + softcap, + /*seqlenq_ngroups_swapped default*/false, + /*unpadded_lse default*/false, + /*s aux*/s_aux, + /*q_heads_per_k_heads*/q_heads_per_k_heads); at::Tensor k, v, k_padded, v_padded; if (k_.has_value()) { diff --git a/csrc/flash_attn/flash_api_sparse.cpp b/csrc/flash_attn/flash_api_sparse.cpp index 62a92d8f78..17878673ed 100644 --- 
a/csrc/flash_attn/flash_api_sparse.cpp +++ b/csrc/flash_attn/flash_api_sparse.cpp @@ -55,7 +55,11 @@ void set_params_fprop(Flash_fwd_params ¶ms, int window_size_right, const float softcap, bool seqlenq_ngroups_swapped=false, - const bool unpadded_lse=false); + const bool unpadded_lse=false, + //TODO(dudugong-gitch): sinks + std::optional s_aux_ = std::nullopt, + //TODO(dudugong-gitch):q_heads_per_k_heads + int q_heads_per_k_heads = 1); std::tuple set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q, diff --git a/csrc/flash_attn/flash_api_torch_lib.cpp b/csrc/flash_attn/flash_api_torch_lib.cpp index d1299c54cd..a89adbcbb1 100644 --- a/csrc/flash_attn/flash_api_torch_lib.cpp +++ b/csrc/flash_attn/flash_api_torch_lib.cpp @@ -35,7 +35,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_); + std::optional gen_, + //TODO(dudugong-gitch): sinks + const std::optional &s_aux); std::vector mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size @@ -57,7 +59,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he int window_size_right, const float softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - int num_splits); + int num_splits, + //TODO(dudugong-gitch): sinks + const std::optional &s_aux_); /////////////////////////// From flash_api_sparse.cpp ////////////////////////// @@ -104,18 +108,20 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_ /** * Torch Library Registration */ +// TODO(dudugong-gitch): sync pybind once TORCH_LIBRARY supports optional for sink tokens TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, " "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, " "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, " "bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, " - "Generator? gen) -> Tensor[]"); + "Generator? gen, " + "Tensor? s_aux = None) -> Tensor[]"); ops.impl("varlen_fwd", torch::kCUDA, make_pytorch_shim(&mha_varlen_fwd)); ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? seqlens_k, " "Tensor? rotary_cos, Tensor? rotary_sin, Tensor? cache_batch_idx, Tensor? leftpad_k, Tensor? block_table, " "Tensor? alibi_slopes, Tensor!? out, float softmax_scale, bool is_causal, int window_size_left, " - "int window_size_right, float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]"); + "int window_size_right, float softcap, bool is_rotary_interleaved, int num_splits, Tensor? s_aux_) -> Tensor[]"); ops.impl("fwd_kvcache", torch::kCUDA, make_pytorch_shim(&mha_fwd_kvcache)); ops.def("fwd_sparse(Tensor! q, Tensor k, Tensor v, " diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 8ffbb62d66..be0e18302d 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -41,6 +41,8 @@ struct Qkv_params { // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be // different from nheads (query). 
int h_h_k_ratio; // precompute h / h_k, + //TODO(dudugong-gitch):q_heads_per_k_heads + int q_heads_per_k_heads = 1; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -140,6 +142,8 @@ struct Flash_fwd_params : public Qkv_params { bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). + // TODO(dudugong-gitch): sink-token vector pointer + void *__restrict__ s_aux_ptr = nullptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 7512828154..e7b6d6a356 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -430,8 +430,34 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); - +// TODO(dudugong-gitch): fetch current sink values and inject into normalize_softmax_lse + Tensor lse = [&]{ + if (params.s_aux_ptr && params.seqlenq_ngroups_swapped){ + Tensor s_aux_cur = make_tensor(Shape(acc_o)>>{}); + Tensor cS = make_identity_tensor(Shape, Int>{}); + Tensor tScS = thr_mma.partition_C(cS); + Tensor tScS_rowcol = make_tensor(tScS.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tScS.layout())); + static_assert(size<0>(tScS_rowcol) == size(s_aux_cur)); + #pragma unroll + for(int mi = 0; mi < size(s_aux_cur); ++mi) { + int row = m_block * kBlockM + get<0>(tScS_rowcol(mi, _0{})); + int bidh_mi = (row % params.q_heads_per_k_heads) + bidh * params.q_heads_per_k_heads; + s_aux_cur(mi) = static_cast(reinterpret_cast(params.s_aux_ptr)[bidh_mi]); + } + return softmax.template normalize_softmax_lse(acc_o, params.scale_softmax,/*s_aux_cur=*/s_aux_cur, /*scale_softmax_log2=*/params.scale_softmax_log2); + } + else if(params.s_aux_ptr && !params.seqlenq_ngroups_swapped){ + Tensor s_aux_cur = make_tensor(Shape(acc_o)>>{}); + #pragma unroll + for(int mi = 0; mi < size(s_aux_cur); ++mi) { + s_aux_cur(mi) = static_cast(reinterpret_cast(params.s_aux_ptr)[bidh]); + } + return softmax.template normalize_softmax_lse(acc_o, params.scale_softmax,/*s_aux_cur=*/s_aux_cur, /*scale_softmax_log2=*/params.scale_softmax_log2); + } + else{ + return softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, /*scale_softmax_log2=*/params.scale_softmax_log2); + } + }(); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -1008,8 +1034,34 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons } // Epilogue - - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); +// TODO(dudugong-gitch): fetch current sink values and inject into normalize_softmax_lse + Tensor lse = [&]{ + if (params.s_aux_ptr && params.seqlenq_ngroups_swapped){ + Tensor s_aux_cur = make_tensor(Shape(acc_o)>>{}); + Tensor cS = make_identity_tensor(Shape, Int>{}); + Tensor tScS = thr_mma.partition_C(cS); + Tensor tScS_rowcol = make_tensor(tScS.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tScS.layout())); + static_assert(size<0>(tScS_rowcol) == size(s_aux_cur)); + #pragma unroll + for(int mi = 0; mi < 
size(s_aux_cur); ++mi) { + int row = m_block * kBlockM + get<0>(tScS_rowcol(mi, _0{})); + int bidh_mi = (row % params.q_heads_per_k_heads) + bidh * params.q_heads_per_k_heads; + s_aux_cur(mi) = static_cast(reinterpret_cast(params.s_aux_ptr)[bidh_mi]); + } + return softmax.template normalize_softmax_lse(acc_o, params.scale_softmax,/*s_aux_cur=*/s_aux_cur, /*scale_softmax_log2=*/params.scale_softmax_log2); + } + else if(params.s_aux_ptr && !params.seqlenq_ngroups_swapped){ + Tensor s_aux_cur = make_tensor(Shape(acc_o)>>{}); + #pragma unroll + for(int mi = 0; mi < size(s_aux_cur); ++mi) { + s_aux_cur(mi) = static_cast(reinterpret_cast(params.s_aux_ptr)[bidh]); + } + return softmax.template normalize_softmax_lse(acc_o, params.scale_softmax,/*s_aux_cur=*/s_aux_cur, /*scale_softmax_log2=*/params.scale_softmax_log2); + } + else{ + return softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, /*scale_softmax_log2=*/params.scale_softmax_log2); + } + }(); // if (cute::thread0()) { print(lse); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index 01589adedb..2ae8dc7a78 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -165,6 +165,25 @@ struct Softmax { FLASH_NAMESPACE::reduce_sum(scores, row_sum); } }; + // TODO(dudugong-gitch): inject sink values into normalize_softmax_lse + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, TensorT s_aux_cur, float softmax_scale_log2, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / (sum + exp2f(float(M_LOG2E) * s_aux_cur(mi) - row_max(mi) * softmax_scale_log2)); + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; template __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { diff --git a/tests/test_fa2_sinks.py b/tests/test_fa2_sinks.py new file mode 100644 index 0000000000..566168d12e --- /dev/null +++ b/tests/test_fa2_sinks.py @@ -0,0 +1,132 @@ +""" +Single-file FlashAttention2 with attention sinks accuracy test. 
+
+Usage:
+    pytest tests/test_fa2_sinks.py
+
+Parameters:
+    so_path  - path to the shared library (.so)
+    atol     - absolute tolerance
+    rtol     - relative tolerance
+    csv_path - path to save accuracy results
+"""
+import pytest
+import torch
+import csv
+import os
+import vllm
+from typing import List, Tuple
+
+pkg = os.path.dirname(vllm.__file__)
+so_path = os.path.join(pkg, "vllm_flash_attn", "_vllm_fa2_C.abi3.so")
+torch.ops.load_library(so_path)
+fa2_op = torch.ops._vllm_fa2_C.varlen_fwd
+
+atol = 1e-3
+rtol = 1e-3
+csv_path = os.path.join(".", "test_fa2_sinks_accuracy.csv")
+
+_csv_file = open(csv_path, "a", newline="")
+_writer = csv.writer(_csv_file)
+if os.stat(csv_path).st_size == 0:  # write the header only for an empty file
+    _writer.writerow(["causal", "headdim", "batch_size", "seqlen",
+                      "nheads_q", "q_heads_per_k_heads",
+                      "rtol", "atol", "bad_ratio"])
+
+bs_seqlen_vals = [(32, 1), (32, 512), (16, 1024), (8, 2048),
+                  (4, 4096), (2, 8192), (1, 16384)]
+causal_vals = [False, True]
+headdim_vals = [64, 128]
+nheads_q_vals = [4, 8, 16, 32, 64]
+q_heads_per_k_heads_vals = [4, 2, 1]
+
+
+def _parametrize() -> List[Tuple]:
+    cases = []
+    for causal in causal_vals:
+        for headdim in headdim_vals:
+            for (bs, sq) in bs_seqlen_vals:
+                for nhq in nheads_q_vals:
+                    for qhk in q_heads_per_k_heads_vals:
+                        cases.append((causal, headdim, bs, sq, nhq, qhk))
+    return cases
+
+
+@pytest.mark.parametrize("causal,headdim,batch_size,seqlen,nheads_q,q_heads_per_k_heads",
+                         _parametrize())
+def test_accuracy(causal, headdim, batch_size, seqlen, nheads_q,
+                  q_heads_per_k_heads):
+    device, dtype = "cuda", torch.float16
+    nheads_k = nheads_q // q_heads_per_k_heads
+
+    q_o = torch.randn(batch_size, seqlen, nheads_q, headdim,
+                      device=device, dtype=dtype, requires_grad=False)
+    k_o = torch.randn(batch_size, seqlen, nheads_k, headdim,
+                      device=device, dtype=dtype, requires_grad=False)
+    v_o = torch.randn(batch_size, seqlen, nheads_k, headdim,
+                      device=device, dtype=dtype, requires_grad=False)
+
+    q = q_o.reshape(-1, nheads_q, headdim)
+    k = k_o.reshape(-1, nheads_k, headdim)
+    v = v_o.reshape(-1, nheads_k, headdim)
+    out_buf = torch.empty_like(q)
+    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen,
+                              device=device, dtype=torch.int32)
+    s_aux = torch.randn(nheads_q, device=device, dtype=dtype)
+
+    fa2_op(q, k, v, out_buf,
+           cu_seqlens, cu_seqlens,
+           None, None, None, None,
+           seqlen, seqlen,
+           0.0,
+           torch.tensor(headdim ** -0.5, device=device),
+           False, causal,
+           -1, -1, 0.0,
+           False, None, s_aux)
+
+    out_buf = out_buf.reshape(batch_size, seqlen, nheads_q, headdim)
+
+    bad_count = 0
+    total_count = 0
+    for b in range(batch_size):
+        for h in range(nheads_q):
+            q1 = q_o[b, :, h, :].float()
+            k1 = k_o[b, :, h // q_heads_per_k_heads, :].float()
+            v1 = v_o[b, :, h // q_heads_per_k_heads, :].float()
+
+            scores = (q1 @ k1.T) * (headdim ** -0.5)
+            if causal:
+                causal_mask = torch.triu(
+                    torch.full((seqlen, seqlen), float("-inf"),
+                               device=device, dtype=torch.float32), diagonal=1)
+                scores = scores + causal_mask
+
+            sink_col = s_aux[h].float().view(1, 1).expand(seqlen, 1)
+            scores_ext = torch.cat([sink_col, scores], dim=1)
+            attn = torch.softmax(scores_ext, dim=1)
+            v_ext = torch.cat([torch.zeros(1, v1.shape[1], device=device, dtype=torch.float32),
+                               v1], dim=0)
+            ref = (attn @ v_ext).to(dtype)
+
+            bad = ~torch.isclose(ref, out_buf[b, :, h, :], atol=atol, rtol=rtol)
+            bad_count += bad.sum().item()
+            total_count += ref.numel()
+
+    ratio = bad_count / total_count * 100
+
+    _writer.writerow([causal, headdim, batch_size, seqlen, nheads_q,
+                      q_heads_per_k_heads, rtol, atol, f"{ratio:.2f}%"])
+    _csv_file.flush()
+    assert ratio == 0, f"Bad ratio {ratio:.2f}% ({bad_count}/{total_count}) > 0"
+
+
+def pytest_sessionfinish(session, exitstatus):
+    # The CSV header is written at import time when the file is empty and each
+    # test writes and flushes its own row, so there is nothing left to write
+    # here; just close the file at the end of the session. Note that pytest
+    # only discovers this hook from a conftest.py (or a registered plugin),
+    # so it is inert while it lives in this test module.
+    _csv_file.close()
\ No newline at end of file
diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index 27ef088cca..246c20d9fe 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -229,8 +229,6 @@ def flash_attn_varlen_func(
             "FA2 does not support scheduler_metadata, q_descale, "
             "k_descale, v_descale"
         )
-        if s_aux is not None:
-            raise NotImplementedError("FA2 does not support s_aux")
         if num_splits > 1:
             raise NotImplementedError("FA2 does not support num_splits > 1")
         out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
@@ -255,6 +253,7 @@ def flash_attn_varlen_func(
             softcap,
             return_softmax_lse and dropout_p > 0,
             None,
+            s_aux
        )
    elif fa_version == 3:
        assert alibi_slopes is None, "Alibi is not supported in FA3"
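
Reviewer note (not part of the patch): a minimal PyTorch sketch of the math the s_aux changes implement, mirroring the reference construction in tests/test_fa2_sinks.py. The per-head sink logit s_aux[h] adds exp(s_aux[h]) to the softmax denominator but contributes no value vector. The function name and shapes below are illustrative only.

import torch


def single_head_attention_with_sink(q, k, v, s_aux, causal=True):
    """Reference for one (batch, query-head) slice.

    q, k, v: (seqlen, headdim) tensors; s_aux: 0-dim tensor holding the
    per-head sink logit. The sink only enlarges the softmax denominator;
    it carries no value row.
    """
    d = q.shape[-1]
    sink = s_aux.float().reshape(1, 1)                          # (1, 1)
    scores = (q.float() @ k.float().T) * d ** -0.5              # (S, S) logits
    if causal:
        scores = scores + torch.triu(
            torch.full_like(scores, float("-inf")), diagonal=1)
    # Shift by a shared row-wise max for numerical stability; the kernel keeps
    # a running row_max and folds the sink term into the row sum instead.
    row_max = torch.maximum(scores.amax(dim=-1, keepdim=True), sink)
    p = torch.exp(scores - row_max)                             # (S, S)
    denom = p.sum(dim=-1, keepdim=True) + torch.exp(sink - row_max)
    return (p / denom) @ v.float()                              # (S, headdim)

This is equivalent to appending one extra logit column holding s_aux[h], taking the softmax, and attending to a zero value row (the construction used in the test's reference); the kernel reaches the same result without materializing that column by adding the exp2f(M_LOG2E * s_aux - row_max * softmax_scale_log2) term to the row sum inside the new normalize_softmax_lse overload in csrc/flash_attn/src/softmax.h.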