[SYCL-TLA] Integrate FlashAttention fwd/bwd kernels #2341
Conversation
EikanWang left a comment:
TBH, I cannot quite understand the detailed implementation. I need to take more time to understand the logic.
```diff
  file(GLOB xpu_cpp "xpu/*.cpp")
- file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
+ file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" "native/transformers/xpu/flash_attn/*.cpp")
```
Nit: I think we should install the header files under flash_attn into PyTorch, similar to line 42.
May I know the purpose of installing the header files?
It gives users a chance to use them in a cpp extension.
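For context, a minimal sketch of what that could look like from an out-of-tree C++ extension; the include path and op name below are hypothetical, not the actual install layout:

```cpp
// Hypothetical extension source. Once the flash_attn headers are installed
// with PyTorch, an out-of-tree C++ extension could include them like any
// other ATen header (the include path below is illustrative only).
#include <torch/extension.h>
// #include <ATen/native/transformers/xpu/flash_attn/flash_api.h>  // hypothetical path

at::Tensor my_op(const at::Tensor& x) {
  // Placeholder body: with the headers available, the installed kernels
  // could be invoked here from the extension.
  return x;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("my_op", &my_op, "example op built against the installed headers");
}
```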
I see.
Done.
@guangyey , I think PyTorch does not expose flash_attn because it is the underlying logic of sdpa, which is exposed as a backend. Meanwhile, I don't believe users invoke PyTorch's flash_attn directly, because Dao-AILab/flash-attention is a better choice.
Meanwhile, the namespace of these functions is sycltla. It is weird to let users invoke sycl-tla-specific functions.
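For reference, the usual way users reach these kernels is through the public SDPA op, with the flash-attention path selected internally as a backend; a minimal sketch with illustrative shapes:

```cpp
#include <ATen/ATen.h>

// Illustrative only: users call the public SDPA op; the dispatcher and the
// SDPA backend selection decide whether the flash-attention kernels run.
at::Tensor sdpa_example() {
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kXPU);
  auto q = at::randn({2, 8, 128, 64}, opts);  // [batch, num_heads, seqlen, head_dim]
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);
  return at::scaled_dot_product_attention(q, k, v);
}
```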
Force-pushed from 770035a to 442c445.
```cpp
  out = at::empty({batch_size, numhead_qo, seqlen_qo, headsize_vo}, opts);
} else if (layout == ATTN_TENSOR_LAYOUT::BSHD) {
  out = at::empty({batch_size, seqlen_qo, numhead_qo, headsize_vo}, opts)
            .permute({0, 2, 1, 3});
```
Why do we need to permute here?
The output is initialized as BSHD contiguous, but the shape should be BHSD for SDPA. Hence we need to permute the seqlen and numhead dimensions.
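A small standalone sketch of that layout trick, with illustrative names: allocate BSHD-contiguous storage and expose it as a BHSD view via permute.

```cpp
#include <ATen/ATen.h>

// Illustrative helper: the returned tensor has logical shape [B, H, S, D]
// (what SDPA expects) while the underlying storage stays [B, S, H, D] contiguous.
at::Tensor make_bhsd_view(int64_t batch, int64_t num_heads, int64_t seqlen,
                          int64_t head_dim, const at::TensorOptions& opts) {
  auto out = at::empty({batch, seqlen, num_heads, head_dim}, opts);  // BSHD storage
  return out.permute({0, 2, 1, 3});  // BHSD view, no data movement
}
```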
Force-pushed from 2eb4cd9 to 95f9c65.
@LuFinch , should we land this PR now?

@EikanWang No. CI failed at build. Checking whether it is a driver issue...

The CD docker's driver from rhel-8.8 is too old and cannot find the Intel 2D load symbol. We need to upgrade the driver to rhel-8.10.
```cpp
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fp8_to_fp16.h"
```
I suppose the fp8-related feature is not part of this PR?
Yes. I just copied the kernel files from sycltla directly; no code cleanup has been done yet.
```cpp
for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
  copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ);
  copy(params.gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK);
  if constexpr (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
```
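For readers unfamiliar with the pattern in the hunk above: the `if constexpr` branch on an element-type trait is compiled only when Q and K are fp8, so other instantiations carry no fp8 code. A simplified standalone illustration (the trait and element types here are stand-ins, not the real CUTLASS ones):

```cpp
#include <cstdint>
#include <type_traits>

// Simplified stand-ins for the real element types and the is_fp8_v trait.
struct float_e4m3_t { uint8_t storage; };
struct half_t { uint16_t storage; };

template <class T>
inline constexpr bool is_fp8_v = std::is_same_v<T, float_e4m3_t>;

template <class ElementQ, class ElementK>
void load_qk_tile(/* tile arguments elided */) {
  // Global-memory copies of the Q and K tiles would happen here (omitted).
  if constexpr (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
    // Compiled only for fp8 instantiations; e.g. an fp8 -> fp16 up-conversion
    // of the just-loaded tiles would go here.
  }
}

// Example instantiations: only the first one contains the fp8 branch.
template void load_qk_tile<float_e4m3_t, float_e4m3_t>();
template void load_qk_tile<half_t, half_t>();
```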
This PR is quite large; should fp8 be covered by this PR?
```cpp
auto kv_head_coord = q_head_coord / q_group_size;
int offset_q = 0, offset_k = 0, offset_v = 0;

if constexpr (is_var_len) {
```
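As a side note on the hunk above, the division `q_head_coord / q_group_size` is the usual grouped-query-attention mapping of query heads onto shared KV heads; a tiny standalone illustration with made-up head counts:

```cpp
#include <cstdio>

int main() {
  // Illustrative numbers: 8 query heads sharing 2 KV heads.
  const int num_q_heads = 8;
  const int num_kv_heads = 2;
  const int q_group_size = num_q_heads / num_kv_heads;  // 4 query heads per KV head

  for (int q_head = 0; q_head < num_q_heads; ++q_head) {
    // Query heads 0-3 map to KV head 0, query heads 4-7 to KV head 1.
    std::printf("q_head %d -> kv_head %d\n", q_head, q_head / q_group_size);
  }
  return 0;
}
```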
It seems that var length is not validated by the UT now? If it is not part of the goal of this PR, I suggest enabling it in another PR and adding the related UT.
Currently I copy these collective and kernel files from sycltla to torch-xpu-ops directly, because they update the code frequently and copying the entire files makes rebasing easier. I think we could merge the code into torch-xpu-ops first; then we can do code cleanup once the sycltla code is stable.
Force-pushed from 95f9c65 to 89c6a49.
Performance outliers, please check!
This PR moves the sycltla kernels in pytorch/pytorch#167056 into torch-xpu-ops.
This PR is based on #2030. When that build PR merges, I will rebase this PR.