From 447726280153607f934b8b2337b202d048b07d0d Mon Sep 17 00:00:00 2001
From: "fengqing.lu"
Date: Wed, 26 Nov 2025 22:18:05 -0800
Subject: [PATCH 1/3] fix perf

---
 .../sycltla/kernel/xe_sdpa_fwd_bshd.h  |  4 ++--
 .../xpu/flash_attn/sycltla/mha_fwd.cpp | 24 +++++++++----------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h
index 7b903f4fc9..2fe2ee3eeb 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h
@@ -191,7 +191,7 @@ class FMHAPrefill {
   }
 
   // Find the length of the longest non masked sequence within that subgroup
-  int calculate_longest_non_masked_length(
+  CUTLASS_DEVICE int calculate_longest_non_masked_length(
       const int& seq_len_kv,
       const int& seq_len_qo,
       const int& last_seq_coord,
@@ -222,7 +222,7 @@ class FMHAPrefill {
   }
 
   template
-  void handle_corner_cases(
+  CUTLASS_DEVICE void handle_corner_cases(
       Tensor& tSr,
       const int& thread_idx,
       const int& SubgroupSize,
diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
index 2ac153ad99..5fae07eb81 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
@@ -333,19 +333,19 @@ void run_mha_fwd_(
           TileShapeOutPut,
           SubgroupLayout,
           PipelineStages);
+    } else {
+      constexpr int PipelineStages = 2;
+      using TileShapeQK = Shape<_256, _32, _64>;
+      using TileShapePV = Shape<_256, _32, _32>;
+      using TileShapeOutPut = Shape<_256, _128, _32>;
+      using SubgroupLayout = Layout, Stride<_1, _1, _1>>;
+      run_mha_fwd_specialized(
+          TileShapeQK,
+          TileShapePV,
+          TileShapeOutPut,
+          SubgroupLayout,
+          PipelineStages);
     }
-
-    constexpr int PipelineStages = 2;
-    using TileShapeQK = Shape<_256, _32, _64>;
-    using TileShapePV = Shape<_256, _32, _32>;
-    using TileShapeOutPut = Shape<_256, _128, _32>;
-    using SubgroupLayout = Layout, Stride<_1, _1, _1>>;
-    run_mha_fwd_specialized(
-        TileShapeQK,
-        TileShapePV,
-        TileShapeOutPut,
-        SubgroupLayout,
-        PipelineStages);
   } else if (headdim == 192) {
     constexpr int PipelineStages = 2;
     using TileShapeQK = Shape<_256, _64, _64>;

From 4c8e6b0d8e3f34649a1e279e8d625b7a12ac9be7 Mon Sep 17 00:00:00 2001
From: "fengqing.lu"
Date: Thu, 27 Nov 2025 22:01:12 -0800
Subject: [PATCH 2/3] remove intel_gpu_bmg_g31 due to stock pytorch CI is 2025.1

---
 .../native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp | 5 ++---
 .../native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp | 5 ++---
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
index aa467b4ea9..73b3ec2440 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
@@ -1464,15 +1464,14 @@ std::tuple flash_attention_backward_sycltla(
       std::array{
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31};
+          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
   if (std::find(
           supported_architectures.begin(),
           supported_architectures.end(),
           device_architecture) == supported_architectures.end()) {
     TORCH_CHECK(
         false,
-        "XPU device architecture does not support flash attention backward. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31.");
+        "XPU device architecture does not support flash attention backward. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21.");
   }
 
   auto grad_query = at::empty_like(query);
diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
index 5fae07eb81..b7caf6066d 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
@@ -540,15 +540,14 @@ flash_attention_forward_sycltla(
       std::array{
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31};
+          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
   if (std::find(
           supported_architectures.begin(),
           supported_architectures.end(),
           device_architecture) == supported_architectures.end()) {
     TORCH_CHECK(
         false,
-        "XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31.");
+        "XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21.");
   }
 
   auto problem_shape = ProblemShapeRegular(

From 57c6114954469608aa34c71eca6d23977d169e7d Mon Sep 17 00:00:00 2001
From: "fengqing.lu"
Date: Thu, 27 Nov 2025 22:32:17 -0800
Subject: [PATCH 3/3] refine

---
 src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp | 2 +-
 src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
index 73b3ec2440..5605292c77 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
@@ -1461,7 +1461,7 @@ std::tuple flash_attention_backward_sycltla(
           .get_info<
               sycl::ext::oneapi::experimental::info::device::architecture>();
   constexpr auto supported_architectures =
-      std::array{
+      std::array{
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
           sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
index b7caf6066d..cda198b8e8 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
@@ -537,7 +537,7 @@ flash_attention_forward_sycltla(
           .get_info<
               sycl::ext::oneapi::experimental::info::device::architecture>();
   constexpr auto supported_architectures =
-      std::array{
+      std::array{
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
           sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
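
Note (not part of the patches): the device gate touched by patches 2 and 3 reduces to the standalone sketch below. It assumes a DPC++ compiler implementing the sycl_ext_oneapi_device_architecture extension; the function name flash_attention_supported and the bare-device usage are illustrative only, not the ATen entry point.

// Hedged sketch of the architecture check; flash_attention_supported is a
// hypothetical helper name, not an API from the patched files.
#include <algorithm>
#include <array>
#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental;

bool flash_attention_supported(const sycl::device& dev) {
  // Query the architecture enumerator via the oneAPI device-architecture
  // extension, mirroring the get_info<...> call in mha_fwd.cpp/mha_bwd.cpp.
  const auto arch = dev.get_info<syclex::info::device::architecture>();
  // Per patch 2, intel_gpu_bmg_g31 is dropped so the code builds with the
  // 2025.1 toolchain used by stock PyTorch CI; only PVC, PVC-VG and BMG G21
  // remain on the allow list.
  constexpr auto supported = std::array{
      syclex::architecture::intel_gpu_pvc,
      syclex::architecture::intel_gpu_pvc_vg,
      syclex::architecture::intel_gpu_bmg_g21};
  return std::find(supported.begin(), supported.end(), arch) != supported.end();
}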