
Conversation

@LuFinch
Contributor

@LuFinch LuFinch commented Nov 12, 2025

This PR moves the sycltla kernels in pytorch/pytorch#167056 into torch-xpu-ops.

This PR is based on #2030. When the build PR merges, I will rebase this PR.

Contributor

@EikanWang EikanWang left a comment


TBH, I cannot quite understand the detailed implementation. I need to take more time to understand the logic.


file(GLOB xpu_cpp "xpu/*.cpp")
-file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
+file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" "native/transformers/xpu/flash_attn/*.cpp")
Contributor


Nit: I think we should install the header files under flash_attn into PyTorch, as is done at line 42.

Contributor Author


May I know what the purpose of installing the header files is?

Contributor


It gives users a chance to use them in a C++ extension.
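
For context, once the headers are installed into PyTorch's include tree, a C++ extension can include them and call the kernels directly. A minimal sketch, assuming a hypothetical installed header path and using only the public SDPA API (the flash_attn header name and any sycltla entry point are not confirmed by this PR):

```cpp
// ext.cpp -- minimal sketch of a PyTorch C++ extension that could consume the
// installed flash_attn headers; the commented include path is a hypothetical placeholder.
#include <torch/extension.h>
// #include <ATen/native/transformers/xpu/flash_attn/...>  // hypothetical installed header

at::Tensor my_sdpa(const at::Tensor& q, const at::Tensor& k, const at::Tensor& v) {
  // For now this dispatches through the public SDPA API; with installed headers,
  // a flash-attention entry point could be called here directly instead.
  return at::scaled_dot_product_attention(q, k, v);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("my_sdpa", &my_sdpa, "SDPA wrapper (sketch)");
}
```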

Contributor Author


I see.

Contributor Author


Done.

Contributor


@guangyey, I think PyTorch does not expose flash_attn because it is the underlying logic of sdpa, which is exposed as a backend. Meanwhile, I don't believe users invoke PyTorch's flash_attn directly, because dao/flash_attn is a better choice.

Contributor


Meanwhile, the namespace of these functions is sycltla. It is weird to let users invoke sycl-tla-specific functions.

Copilot AI review requested due to automatic review settings November 13, 2025 05:52


@LuFinch LuFinch force-pushed the lfq/flash_attention branch from 770035a to 442c445 Compare November 13, 2025 05:55
out = at::empty({batch_size, numhead_qo, seqlen_qo, headsize_vo}, opts);
} else if (layout == ATTN_TENSOR_LAYOUT::BSHD) {
out = at::empty({batch_size, seqlen_qo, numhead_qo, headsize_vo}, opts)
.permute({0, 2, 1, 3});


Why is the permute needed here?

Contributor Author


The output is initialized as BSHD-contiguous, but the shape should be BHSD for SDPA. Hence it needs to permute the seqlen and numhead dimensions.
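
For illustration, a minimal standalone ATen sketch of this pattern (not the PR's exact code; names are illustrative): allocate BSHD-contiguous storage, then expose it as a BHSD view.

```cpp
#include <ATen/ATen.h>

// Allocate the output with BSHD-contiguous storage, then permute so the logical
// shape is BHSD ([batch, num_heads, seq_len, head_size]) as SDPA expects, while
// the underlying memory the kernel writes into stays laid out as BSHD.
at::Tensor alloc_bhsd_view(int64_t batch, int64_t seqlen, int64_t numhead,
                           int64_t headsize, const at::TensorOptions& opts) {
  at::Tensor out = at::empty({batch, seqlen, numhead, headsize}, opts)  // BSHD storage
                       .permute({0, 2, 1, 3});                          // BHSD view
  // out.sizes() == {batch, numhead, seqlen, headsize}; out.is_contiguous() == false.
  return out;
}
```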

@LuFinch LuFinch force-pushed the lfq/flash_attention branch from 2eb4cd9 to 95f9c65 Compare November 17, 2025 03:04
@EikanWang
Contributor

@LuFinch, should we land this PR now?

@LuFinch
Contributor Author

LuFinch commented Nov 17, 2025

@EikanWang No. CI failed at the build stage. Checking whether it is a driver issue...

InvalidModule: Invalid SPIR-V module: input SPIR-V module uses unknown extension 'SPV_INTEL_2d_block_io'
 Undefined function _Z45intel_sub_group_2d_block_prefetch_16b_4r16x2cPU3AS1viiiDv2_i found in ... This may result in runtime errors.

@LuFinch
Contributor Author

LuFinch commented Nov 17, 2025

The CD docker's driver from rhel-8.8 is too old and can't find the Intel 2D block load symbol. We need to upgrade the driver to rhel-8.10.

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fp8_to_fp16.h"


I suppose the fp8-related feature is not part of this PR?

Contributor Author


Yes. I just copied the kernel file from sycltla directly; no code cleanup has been done yet.

for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ);
copy(params.gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK);
if constexpr (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {


This PR is quite large; should fp8 be covered by it?

auto kv_head_coord = q_head_coord / q_group_size;
int offset_q = 0, offset_k = 0, offset_v = 0;

if constexpr (is_var_len) {


It seems that var length is not validated by the UT now? If it is not part of the goal of this PR, I suggest enabling it in another PR and adding the related UT.

Contributor Author


Currently I copy these collective and kernel files from sycltla to torch-xpu-ops directly, because they update the code frequently and copying entire files makes rebasing easier. I think we could merge the code into torch-xpu-ops first, then do the code cleanup once the sycltla code is stable.

@LuFinch LuFinch force-pushed the lfq/flash_attention branch from 95f9c65 to 89c6a49 Compare November 18, 2025 08:28
@github-actions

Performance outliers, please check!

• 🔴 [-1, 80%), should be regression

| Category | Model | Target vs. Baseline [Eager] | Target vs. Baseline [Inductor] |
| --- | --- | --- | --- |
| torchbench_bfloat16_training | pytorch_unet | 1.040893 | 0.705059 |
