[Kernel] add attention sinks for flash attention2 #103
I have implemented attention-sink support in FlashAttention-2 and validated it on attention-sink models such as GPT-OSS-20B. Below is a speed comparison between the sink-enabled FlashAttention-2 and the original FA-2:
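For reference, an attention sink adds a per-head logit to the softmax denominator without contributing a value vector, so it can absorb probability mass that would otherwise be forced onto real tokens. Below is a minimal pure-NumPy sketch of that math for a single query and head; the function name and formulation are illustrative and are not the kernel code from this PR.

```python
import numpy as np

def attention_with_sink(q, k, v, sink_logit):
    """Reference (non-flash) attention with a sink logit.

    The sink enters the softmax denominator but carries no value,
    so it only scales down the attention weights on real tokens.
    q: (d,), k: (n, d), v: (n, d), sink_logit: scalar.
    """
    d = q.shape[-1]
    scores = k @ q / np.sqrt(d)           # (n,) scaled dot-product scores
    m = max(scores.max(), sink_logit)     # max over scores AND the sink, for stability
    exp_scores = np.exp(scores - m)
    denom = exp_scores.sum() + np.exp(sink_logit - m)  # sink joins the denominator only
    return (exp_scores / denom) @ v       # (d,) weighted sum of values
```

With `sink_logit = -inf` this reduces exactly to standard softmax attention; a very large sink logit instead drains almost all attention mass, driving the output toward zero. In a FlashAttention-style kernel the same effect is obtained by folding `exp(sink_logit - m)` into the running softmax denominator, which is why the overhead in the benchmarks below is near zero.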
| causal | headdim | batch_size | seqlen | TFLOPs/s (original FA-2) | TFLOPs/s (with sinks) | ratio (old / new) |
|--------|---------|------------|--------|--------------------------|-----------------------|-------------------|
| False  | 128     | 1          | 16384  | 80.97                    | 81.31                 | 0.9958            |
| False  | 128     | 16         | 1024   | 72.67                    | 72.86                 | 0.9974            |
| False  | 128     | 2          | 8192   | 80.34                    | 80.70                 | 0.9955            |
| False  | 128     | 32         | 512    | 67.15                    | 66.18                 | 1.0147            |
| False  | 128     | 4          | 4096   | 79.16                    | 79.61                 | 0.9943            |
| False  | 128     | 8          | 2048   | 76.68                    | 76.92                 | 0.9969            |
| False  | 64      | 1          | 16384  | 86.70                    | 87.16                 | 0.9947            |
| False  | 64      | 16         | 1024   | 80.04                    | 79.68                 | 1.0045            |
| False  | 64      | 2          | 8192   | 86.30                    | 85.53                 | 1.0090            |
| False  | 64      | 32         | 512    | 76.18                    | 73.69                 | 1.0338            |
| False  | 64      | 4          | 4096   | 85.14                    | 83.17                 | 1.0237            |
| False  | 64      | 8          | 2048   | 83.69                    | 81.77                 | 1.0235            |
| True   | 128     | 1          | 16384  | 78.55                    | 77.32                 | 1.0159            |
| True   | 128     | 16         | 1024   | 57.83                    | 56.75                 | 1.0190            |
| True   | 128     | 2          | 8192   | 76.22                    | 76.31                 | 0.9988            |
| True   | 128     | 32         | 512    | 45.38                    | 45.03                 | 1.0078            |
| True   | 128     | 4          | 4096   | 72.41                    | 72.09                 | 1.0044            |
| True   | 128     | 8          | 2048   | 65.28                    | 66.34                 | 0.9840            |
| True   | 64      | 1          | 16384  | 83.39                    | 83.30                 | 1.0011            |
| True   | 64      | 16         | 1024   | 63.20                    | 63.68                 | 0.9925            |
| True   | 64      | 2          | 8192   | 81.37                    | 81.66                 | 0.9964            |
| True   | 64      | 32         | 512    | 48.91                    | 49.13                 | 0.9955            |
| True   | 64      | 4          | 4096   | 77.93                    | 77.89                 | 1.0005            |
| True   | 64      | 8          | 2048   | 72.07                    | 72.16                 | 0.9988            |