Skip to content

Commit 730bd35

Browse files
authored
[perf][cpu] Accelerate paged attention GEMMs (QK, PV) on Arm CPUs with NEON (#29193)
Signed-off-by: Fadi Arafeh <[email protected]>
1 parent f55c76c commit 730bd35

File tree

5 files changed

+416
-5
lines changed

5 files changed

+416
-5
lines changed

csrc/cpu/cpu_attn.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,18 @@
1313
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
1414
#endif
1515

16+
// NEON_DISPATCH expands to the `case cpu_attention::ISA::NEON:` arm of an ISA
// switch.
// - On AArch64 builds it includes cpu_attn_neon.hpp, aliases `attn_impl` to
//   AttentionImpl instantiated with ISA::NEON, and returns the result of
//   invoking the callable passed as the macro argument. `scalar_t` and
//   `head_dim` are not defined here and must already be in scope at the
//   expansion site.
// - On every other target it expands to a bare case label with no body, so
//   control falls through to whichever case follows it in the switch.
#ifdef __aarch64__
17+
#include "cpu_attn_neon.hpp"
18+
#define NEON_DISPATCH(...) \
19+
case cpu_attention::ISA::NEON: { \
20+
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
21+
scalar_t, head_dim>; \
22+
return __VA_ARGS__(); \
23+
}
24+
#else
25+
#define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
26+
#endif // #ifdef __aarch64__
27+
1628
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
1729
case HEAD_DIM: { \
1830
constexpr size_t head_dim = HEAD_DIM; \
@@ -41,6 +53,7 @@
4153
[&] { \
4254
switch (ISA_TYPE) { \
4355
AMX_DISPATCH(__VA_ARGS__) \
56+
NEON_DISPATCH(__VA_ARGS__) \
4457
case cpu_attention::ISA::VEC: { \
4558
using attn_impl = \
4659
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
@@ -73,6 +86,8 @@ torch::Tensor get_scheduler_metadata(
7386
isa = cpu_attention::ISA::VEC;
7487
} else if (isa_hint == "vec16") {
7588
isa = cpu_attention::ISA::VEC16;
89+
} else if (isa_hint == "neon") {
90+
isa = cpu_attention::ISA::NEON;
7691
} else {
7792
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
7893
}
@@ -158,6 +173,8 @@ void cpu_attn_reshape_and_cache(
158173
return cpu_attention::ISA::VEC;
159174
} else if (isa == "vec16") {
160175
return cpu_attention::ISA::VEC16;
176+
} else if (isa == "neon") {
177+
return cpu_attention::ISA::NEON;
161178
} else {
162179
TORCH_CHECK(false, "Invalid ISA type: " + isa);
163180
}

csrc/cpu/cpu_attn_impl.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "utils.hpp"
1515

1616
namespace cpu_attention {
17-
enum class ISA { AMX, VEC, VEC16 };
17+
// ISA variants; used as the first template argument of AttentionImpl.
// NEON is appended after the existing enumerators, so the implicit values of
// AMX, VEC and VEC16 are unchanged.
enum class ISA { AMX, VEC, VEC16, NEON };
1818

1919
// Primary template with an empty body, parameterised on the ISA variant, the
// element type and the head dimension.
// NOTE(review): presumably each supported (isa, scalar_t, head_dim)
// combination is supplied as a specialisation elsewhere (e.g. per-ISA
// headers) -- TODO confirm.
template <ISA isa, typename scalar_t, int64_t head_dim>
2020
class AttentionImpl {};
@@ -143,6 +143,12 @@ struct AttentionMetadata {
143143
case ISA::VEC:
144144
ss << "VEC, ";
145145
break;
146+
case ISA::VEC16:
147+
ss << "VEC16, ";
148+
break;
149+
case ISA::NEON:
150+
ss << "NEON, ";
151+
break;
146152
}
147153
ss << "workitem_group_num: " << workitem_group_num
148154
<< ", reduction_item_num: " << reduction_item_num

0 commit comments

Comments
 (0)