3 changes: 2 additions & 1 deletion csrc/flash_attn/flash_api.cpp
@@ -533,6 +533,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     int window_size_right,
     const float softcap,
     const bool return_softmax,
+    int num_splits,
     std::optional<at::Generator> gen_) {

     // Otherwise the kernel will be launched from cuda:0 device
@@ -706,7 +707,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     std::tie(softmax_lse_accum, out_accum) =
         set_params_splitkv(params, batch_size, num_heads, head_size,
                            max_seqlen_k, max_seqlen_q, head_size_rounded,
-                           p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
+                           p_dropout, num_splits, get_num_sm(get_current_device()), opts);
 }

 if (leftpad_k_.has_value()) {
3 changes: 2 additions & 1 deletion csrc/flash_attn/flash_api_torch_lib.cpp
@@ -35,6 +35,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     int window_size_right,
     const float softcap,
     const bool return_softmax,
+    int num_splits,
     std::optional<at::Generator> gen_);

 std::vector<at::Tensor>
@@ -109,7 +110,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
     "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, "
     "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
     "bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, "
-    "Generator? gen) -> Tensor[]");
+    "int num_splits, Generator? gen) -> Tensor[]");
 ops.impl("varlen_fwd", torch::kCUDA, make_pytorch_shim(&mha_varlen_fwd));

 ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? seqlens_k, "
1 change: 1 addition & 0 deletions vllm_flash_attn/flash_attn_interface.py
@@ -271,6 +271,7 @@ def flash_attn_varlen_func(
             real_window_size[1],
             softcap,
             return_softmax_lse and dropout_p > 0,
+            num_splits,
             None,
         )
     elif fa_version == 3:
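
Taken together, the three hunks thread a caller-supplied num_splits through the FA2 varlen forward path: mha_varlen_fwd gains the parameter and forwards it to set_params_splitkv in place of the hard-coded /*num_splits*/ 0 (which previously left the split-KV heuristic to choose the split count), the torch library schema is extended to match, and the Python wrapper passes the value along on the fa_version == 2 path. The snippet below is a minimal usage sketch, not part of the patch: it assumes flash_attn_varlen_func exposes num_splits and fa_version as keyword arguments alongside the usual varlen calling convention, and the shapes and num_splits=4 are illustrative only.

# Usage sketch under the assumptions stated above; num_splits=0 would keep the
# previous behaviour of letting the kernel's split-KV heuristic decide.
import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func

num_heads, head_size = 8, 128
seqlens_q = [5, 7]    # two variable-length query sequences packed together
seqlens_k = [16, 32]  # matching key/value sequence lengths
total_q, total_k = sum(seqlens_q), sum(seqlens_k)

q = torch.randn(total_q, num_heads, head_size, device="cuda", dtype=torch.float16)
k = torch.randn(total_k, num_heads, head_size, device="cuda", dtype=torch.float16)
v = torch.randn(total_k, num_heads, head_size, device="cuda", dtype=torch.float16)

# Cumulative sequence lengths, as expected by the varlen interface.
cu_seqlens_q = torch.tensor([0, 5, 12], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 16, 48], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max(seqlens_q),
    max_seqlen_k=max(seqlens_k),
    num_splits=4,   # hypothetical value: force four split-KV partitions
    fa_version=2,   # the new argument is only wired up on the FA2 path here
)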