185 changes: 185 additions & 0 deletions benchmarks/benchmark_fa2_sinks.py
@@ -0,0 +1,185 @@
"""
Single-file benchmark for FlashAttention-2 (FA2) with attention sinks.

Usage:
python3 benchmark_fa2_sinks.py

Configuration (set in this file):
    so_path  - path to the vllm_flash_attn shared library (.so)
    path_new - path of the CSV file where results are written
"""
import csv
import math
import os

import torch
import torch.nn.functional as F  # used by attention_pytorch below
import torch.utils.benchmark as benchmark
from einops import rearrange

import vllm

pkg = os.path.dirname(vllm.__file__)
so_path = os.path.join(pkg, "vllm_flash_attn", "_vllm_fa2_C.abi3.so")
path_new = os.path.join(".", "benchmark_fa2_sinks.csv")

csv_rows = []

def benchmark_forward(
fn, inputs, repeats=10, desc="", verbose=False, amp=False, amp_dtype=torch.float16, **kwinputs
):
"""Use Pytorch Benchmark on the forward pass of an arbitrary function."""
if verbose:
print(desc, "- Forward pass")

    def amp_wrapper(inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

t = benchmark.Timer(
stmt="fn_amp(inputs, **kwinputs)",
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m

def benchmark_fwd_bwd(
fn,
inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
return benchmark_forward(
fn,
inputs,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
)

attention_triton = None
xops = None

def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
assert mode in ["fwd", "bwd", "fwd_bwd"]
f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def efficiency(flop, time):
return (flop / time / 10**12) if not math.isnan(time) else 0.0
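# Worked example with illustrative (not measured) numbers: for batch=2, seqlen=8192,
# nheads=16, headdim=128, causal=True, flops(...) = 4*2*8192**2*16*128 // 2 ≈ 5.5e11 FLOPs,
# so a hypothetical 2 ms forward pass would report efficiency(5.5e11, 2e-3) ≈ 275 TFLOPs/s.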


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, head_dim)
dropout_p: float
Output:
output: (batch_size, seqlen, nheads, head_dim)
"""
batch_size, seqlen, _, nheads, d = qkv.shape
q, k, v = qkv.unbind(dim=2)
q = rearrange(q, 'b t h d -> (b h) t d')
k = rearrange(k, 'b s h d -> (b h) d s')
softmax_scale = 1.0 / math.sqrt(d)
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
'(b h) t s -> b h t s', h=nheads)
if causal:
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1)
attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
return output.to(dtype=qkv.dtype)


def time_fwd_bwd(func, *args, **kwargs):
time_f = benchmark_fwd_bwd(func, *args, **kwargs)
return time_f[1].mean


repeats = 30
device = 'cuda'
dtype = torch.float16

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
def test_time(path, func_name, old_or_new):
for causal in causal_vals:
for headdim in headdim_vals:
for batch_size, seqlen in bs_seqlen_vals:
config = (causal, headdim, batch_size, seqlen)
nheads = dim // headdim

qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
requires_grad=True)
q = qkv[:, :, 0] # (B, S, H, D)
k = qkv[:, :, 1] # (B, S, H, D)
v = qkv[:, :, 2] # (B, S, H, D)

q = q.reshape(-1, nheads, headdim)
k = k.reshape(-1, nheads, headdim)
v = v.reshape(-1, nheads, headdim)
                s_aux = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True)
                out_buf = torch.empty_like(q)
                cu_seqlens = torch.tensor([seqlen * i for i in range(batch_size + 1)],
                                          device=device, dtype=torch.int32)
                # Positional arguments of _vllm_fa2_C.varlen_fwd, in the order of the schema
                # registered in flash_api_torch_lib.cpp.
                fa2_fwd_closure = [
                    q, k, v,
                    out_buf,                    # out
                    cu_seqlens,                 # cu_seqlens_q
                    cu_seqlens,                 # cu_seqlens_k
                    None,                       # seqused_k
                    None,                       # leftpad_k
                    None,                       # block_table
                    None,                       # alibi_slopes
                    seqlen,                     # max_seqlen_q
                    seqlen,                     # max_seqlen_k
                    dropout_p,                  # p_dropout
                    1.0 / math.sqrt(headdim),   # softmax_scale (float, per the op schema)
                    False,                      # zero_tensors
                    causal,                     # is_causal
                    -1,                         # window_size_left
                    -1,                         # window_size_right
                    0.0,                        # softcap
                    dropout_p > 0,              # return_softmax
                    None,                       # gen
                    s_aux,                      # per-head attention-sink logits
                ]
f = time_fwd_bwd(func_name, fa2_fwd_closure, repeats=repeats, verbose=False)
time_f[config, "Flash2"] = f
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
speed_f[config, "Flash2"] = efficiency(
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
time_f[config, "Flash2"])
                print(f"Flash2 fwd: {speed_f[config, 'Flash2']:.2f} TFLOPs/s")
                csv_rows.append([causal, headdim, batch_size, seqlen,
                                 f"{speed_f[config, 'Flash2']:.2f}"])
    with open(path, "a", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(["causal", "headdim", "batch_size", "seqlen", "TFLOPs/s"])
        writer.writerows(csv_rows)

    print(f"Results written to {path}")

torch.ops.load_library(so_path)
func_name = torch.ops._vllm_fa2_C.varlen_fwd
test_time(path_new, func_name, "new")
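
For context, below is a minimal eager-mode PyTorch sketch of the attention-sink semantics that the new s_aux argument is assumed to implement: each head carries one extra learned logit that joins the softmax normalization but has no value vector, so it only absorbs probability mass. This is an illustrative reference under that assumption, not the kernel's code path; sink_attention_reference is a hypothetical helper name.

import math
import torch

def sink_attention_reference(q, k, v, s_aux, causal=True):
    # q, k, v: (batch, seqlen, nheads, headdim); s_aux: (nheads,) per-head sink logits.
    b, s, h, d = q.shape
    scale = 1.0 / math.sqrt(d)
    scores = torch.einsum("bthd,bshd->bhts", q, k) * scale           # (b, h, t, s)
    if causal:
        future = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(future, float("-inf"))
    # One extra "sink" column per head joins the softmax but is dropped afterwards.
    sink = s_aux.view(1, h, 1, 1).expand(b, h, s, 1).to(scores.dtype)
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)[..., :-1]
    return torch.einsum("bhts,bshd->bthd", probs, v)

When s_aux is very negative the sink column receives essentially zero probability and the result reduces to ordinary softmax attention.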
49 changes: 43 additions & 6 deletions csrc/flash_attn/flash_api.cpp
@@ -49,7 +49,11 @@ void set_params_fprop(Flash_fwd_params &params,
int window_size_right,
const float softcap,
bool seqlenq_ngroups_swapped=false,
const bool unpadded_lse=false) {
const bool unpadded_lse=false,
//TODO(dudugong-gitch): sinks
const std::optional<at::Tensor> s_aux_ = std::nullopt,
//TODO(dudugong-gitch):q heads per k heads
int q_heads_per_k_heads = 1) {

// Reset the parameters
params = {};
@@ -81,6 +85,23 @@
params.o_batch_stride *= seqlen_q;
}
}
// TODO(dudugong-gitch): sink-tokens support – integrate into params
    if (s_aux_.has_value()) {
        TORCH_CHECK(q.size(-1) == v.size(-1),
                    "We don't support S_aux with hdim != hdim_v");
        auto s_aux = s_aux_.value();
        TORCH_CHECK(s_aux.device() == q.device(),
                    "s_aux must be on the same device as q");
        TORCH_CHECK(s_aux.ndimension() == 1,
                    c10::str("s_aux must be 1-D, but got ", s_aux.ndimension(), "-D."));
        TORCH_CHECK(s_aux.scalar_type() == q.scalar_type(),
                    "s_aux and q must have the same dtype");
        params.s_aux_ptr = s_aux.data_ptr();
    } else {
        params.s_aux_ptr = nullptr;
    }
//TODO(dudugong-gitch): q_heads_per_k_heads
params.q_heads_per_k_heads = q_heads_per_k_heads;

params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
@@ -533,7 +554,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
std::optional<at::Generator> gen_,
//TODO(dudugong-gitch): sinks
const std::optional<at::Tensor> &s_aux = std::nullopt) {

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};
@@ -588,6 +611,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s

void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

//TODO(dudugong-gitch): validate shape
TORCH_CHECK(!s_aux.has_value() || s_aux.value().size(0) == q.size(-2),
c10::str("s_aux.size(0) must equal the number of heads of q (",
q.size(-2), "), but got ", s_aux.value().size(0), "."));

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
@@ -689,7 +717,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
window_size_right,
softcap,
seqlenq_ngroups_swapped,
/*unpadded_lse*/true);
/*unpadded_lse*/true,
/*s aux*/s_aux,
/*q_heads_per_k_heads*/ngroups);
params.total_q = total_q;

if (paged_KV) {
Expand Down Expand Up @@ -1251,7 +1281,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
int num_splits,
//TODO(dudugong-gitch): sinks
const std::optional<at::Tensor> &s_aux = std::nullopt
) {

// Otherwise the kernel will be launched from cuda:0 device
@@ -1309,6 +1341,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
//TODO(dudugong-gitch): q_heads_per_k_heads
const int q_heads_per_k_heads = num_heads / num_heads_k;
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
@@ -1384,8 +1418,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
softmax_scale,
window_size_left,
window_size_right,
softcap
);
softcap,
/*seqlenq_ngroups_swapped default*/false,
/*unpadded_lse default*/false,
/*s aux*/s_aux,
/*q_heads_per_k_heads*/q_heads_per_k_heads);

at::Tensor k, v, k_padded, v_padded;
if (k_.has_value()) {
6 changes: 5 additions & 1 deletion csrc/flash_attn/flash_api_sparse.cpp
@@ -55,7 +55,11 @@ void set_params_fprop(Flash_fwd_params &params,
int window_size_right,
const float softcap,
bool seqlenq_ngroups_swapped=false,
const bool unpadded_lse=false);
const bool unpadded_lse=false,
//TODO(dudugong-gitch): sinks
std::optional<at::Tensor> s_aux_ = std::nullopt,
//TODO(dudugong-gitch):q_heads_per_k_heads
int q_heads_per_k_heads = 1);

std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
14 changes: 10 additions & 4 deletions csrc/flash_attn/flash_api_torch_lib.cpp
@@ -35,7 +35,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);
std::optional<at::Generator> gen_,
//TODO(dudugong-gitch): sinks
const std::optional<at::Tensor> &s_aux);

std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
@@ -57,7 +59,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits);
int num_splits,
//TODO(dudugong-gitch): sinks
const std::optional<at::Tensor> &s_aux_);

/////////////////////////// From flash_api_sparse.cpp //////////////////////////

@@ -104,18 +108,20 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
/**
* Torch Library Registration
*/
// TODO(dudugong-gitch): sync pybind once TORCH_LIBRARY supports optional<Tensor> for sink tokens
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
"bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, "
"Generator? gen) -> Tensor[]");
"Generator? gen, "
"Tensor? s_aux = None) -> Tensor[]");
ops.impl("varlen_fwd", torch::kCUDA, make_pytorch_shim(&mha_varlen_fwd));

ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? seqlens_k, "
"Tensor? rotary_cos, Tensor? rotary_sin, Tensor? cache_batch_idx, Tensor? leftpad_k, Tensor? block_table, "
"Tensor? alibi_slopes, Tensor!? out, float softmax_scale, bool is_causal, int window_size_left, "
"int window_size_right, float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]");
"int window_size_right, float softcap, bool is_rotary_interleaved, int num_splits, Tensor? s_aux_) -> Tensor[]");
ops.impl("fwd_kvcache", torch::kCUDA, make_pytorch_shim(&mha_fwd_kvcache));

ops.def("fwd_sparse(Tensor! q, Tensor k, Tensor v, "
4 changes: 4 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -41,6 +41,8 @@ struct Qkv_params {
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
//TODO(dudugong-gitch):q_heads_per_k_heads
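    // Number of query heads that share each key/value head (num_heads / num_heads_k, the GQA group size).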
int q_heads_per_k_heads = 1;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -140,6 +142,8 @@ struct Flash_fwd_params : public Qkv_params {

bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
// TODO(dudugong-gitch): sink-token vector pointer
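    // Per the checks in flash_api.cpp: a 1-D tensor of length num_heads, same dtype and device as q; nullptr disables sinks.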
void *__restrict__ s_aux_ptr = nullptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////