@@ -20,11 +20,11 @@ template <typename FMHA> static auto run(typename FMHA::Params params) -> void {
2020
2121 int smem_size = FMHA::SharedStorageSize;
2222
23- const auto sycl_block = syclcompat ::dim3 (block.x , block.y , block.z );
24- const auto sycl_grid = syclcompat ::dim3 (grid.x , grid.y , grid.z );
23+ const auto sycl_block = compat ::dim3 (block.x , block.y , block.z );
24+ const auto sycl_grid = compat ::dim3 (grid.x , grid.y , grid.z );
2525
2626#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY)
27- using namespace syclcompat ::experimental;
27+ using namespace compat ::experimental;
2828 auto event = launch<cutlass::device_kernel<FMHA>>(
2929 launch_policy{
3030 sycl_grid, sycl_block,
@@ -33,15 +33,15 @@ template <typename FMHA> static auto run(typename FMHA::Params params) -> void {
3333 sycl_exp::sub_group_size<FMHA::DispatchPolicy::SubgroupSize>}},
3434 params);
3535#else
36- syclcompat ::experimental::launch_properties launch_props{
36+ compat ::experimental::launch_properties launch_props{
3737 sycl::ext::oneapi::experimental::work_group_scratch_size (smem_size),
3838 };
39- syclcompat ::experimental::kernel_properties kernel_props{
39+ compat ::experimental::kernel_properties kernel_props{
4040 sycl::ext::oneapi::experimental::sub_group_size<
4141 FMHA::DispatchPolicy::SubgroupSize>};
42- syclcompat ::experimental::launch_policy policy{sycl_grid, sycl_block,
42+ compat ::experimental::launch_policy policy{sycl_grid, sycl_block,
4343 launch_props, kernel_props};
44- auto event = syclcompat ::experimental::launch<cutlass::device_kernel<FMHA>>(
44+ auto event = compat ::experimental::launch<cutlass::device_kernel<FMHA>>(
4545 policy, params);
4646#endif
4747
@@ -102,7 +102,7 @@ static auto attention_run(const at::Tensor &Q, const at::Tensor &K,
102102 using CollectiveEpilogue =
103103 cutlass::flash_attention::collective::FlashPrefillEpilogue<
104104 EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout,
105- ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutO>,
105+ ElementAccumulator, ElementOutput, cutlass::gemm::TagToStrideC_t<LayoutO>,
106106 ElementOutput, GmemTiledCopyStore>;
107107
108108 // / FA ///
@@ -181,7 +181,7 @@ static auto attention_run(const at::Tensor &Q, const at::Tensor &K,
181181 FMHAPrefillKernel::to_underlying_arguments (arguments, workspace_ptr);
182182 run<FMHAPrefillKernel>(params);
183183
184- syclcompat ::wait ();
184+ compat ::wait ();
185185
186186 } catch (std::exception &e) {
187187 std::cerr << " Runtime error: " << e.what () << std::endl;
0 commit comments