Commit 7911533

pengcheng888 committed
issue/586 - Add Python self_attention implementation and tests
1 parent 74934cd · commit 7911533

File tree: 6 files changed (+317 −2 lines)


include/infinicore/ops/attention.hpp

Lines changed: 12 additions & 0 deletions

@@ -2,6 +2,7 @@

 #include "../device.hpp"
 #include "common/op.hpp"
+#include <optional>

 namespace infinicore::op {
 class Attention {
@@ -13,4 +14,15 @@ class Attention {

 Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
 void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
+
+Tensor self_attention(Tensor query,
+                      Tensor key,
+                      Tensor value,
+                      std::optional<float> scale);
+
+void self_attention_(Tensor out,
+                     Tensor query,
+                     Tensor key,
+                     Tensor value,
+                     std::optional<float> scale);
 } // namespace infinicore::op
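
A compact statement of what these two declarations compute, summarized from the implementation in attention.cc further down (the header itself does not document this): the optional scale multiplies the query-key product, and a causal softmax is applied before weighting the values.

    \mathrm{Attention}(Q, K, V) = \mathrm{causal\_softmax}\left(\alpha \, Q K^{\top}\right) V,
    \qquad
    \alpha = \begin{cases} \texttt{scale} & \text{if provided} \\ 1/\sqrt{d_{\text{head}}} & \text{otherwise} \end{cases}

self_attention allocates and returns the result tensor; self_attention_ writes into a caller-provided out tensor.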

python/infinicore/nn/functional/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@
 from .random_sample import random_sample
 from .rms_norm import rms_norm
 from .rope import RopeAlgo, rope
+from .self_attention import self_attention
 from .silu import silu
 from .swiglu import swiglu

@@ -17,4 +18,5 @@
     "embedding",
     "rope",
     "RopeAlgo",
+    "self_attention",
 ]
python/infinicore/nn/functional/self_attention.py (new file)

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+from typing import Optional
+
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def self_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    scale: Optional[float] = None,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Computes scaled dot product attention on query, key and value tensors."""
+
+    seq_len = query.shape[-2]
+    total_seq_len = key.shape[-2]
+
+    assert (1 == seq_len and total_seq_len > 1) or (seq_len == total_seq_len), (
+        "Incorrect parameter value."
+    )
+
+    if out is None:
+        return Tensor(
+            _infinicore.self_attention(
+                query._underlying, key._underlying, value._underlying, scale
+            )
+        )
+
+    _infinicore.self_attention_(
+        out._underlying,
+        query._underlying,
+        key._underlying,
+        value._underlying,
+        scale,
+    )
+
+    return out
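
A brief usage sketch of the new wrapper (illustrative only: how the infinicore tensors q, k and v are constructed is outside this diff, and the helper name is hypothetical):

import infinicore.nn.functional as F

def run_attention(q, k, v, out=None, scale=None):
    # q: [bs, num_attention_heads, ntoken, head_dim]
    # k, v: [bs, num_key_value_heads, total_token, head_dim]
    if out is None:
        # Out-of-place path: the result tensor is allocated and returned.
        # scale=None falls back to 1/sqrt(head_dim) on the C++ side.
        return F.self_attention(q, k, v, scale)
    # With a preallocated output, the in-place binding is used and `out` is returned.
    return F.self_attention(q, k, v, scale, out=out)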

src/infinicore/ops/attention/attention.cc

Lines changed: 87 additions & 1 deletion

@@ -1,5 +1,7 @@
 #include "infinicore/ops/attention.hpp"
-
+#include "infinicore/ops/causal_softmax.hpp"
+#include "infinicore/ops/gemm.hpp"
+#include <cmath>
 namespace infinicore::op {

 common::OpDispatcher<Attention::schema> &Attention::dispatcher() {
@@ -25,4 +27,88 @@ void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
     Attention::execute(out, q, k, v, k_cache, v_cache, pos);
 }

+Tensor self_attention(Tensor query_states, // [bs, num_attention_heads, ntoken, head_dim]
+                      Tensor key_states,   // [bs, num_key_value_heads, total_token, head_dim]
+                      Tensor value_states, // [bs, num_key_value_heads, total_token, head_dim]
+                      std::optional<float> scale) {
+
+    auto query_shape = query_states->shape();
+    auto key_shape = key_states->shape();
+
+    Size batch_size = query_shape[0];
+    Size num_attention_heads = query_shape[1];
+    Size ntoken = query_shape[2];
+    Size head_dim = key_shape[3];
+
+    Tensor output_values = Tensor::empty({batch_size, num_attention_heads, ntoken, head_dim}, query_states->dtype(), query_states->device());
+
+    self_attention_(output_values, query_states, key_states, value_states, scale);
+
+    return output_values;
+}
+
+void self_attention_(Tensor out,
+                     Tensor query_states,
+                     Tensor key_states,
+                     Tensor value_states,
+                     std::optional<float> scale) {
+
+    auto query_shape = query_states->shape();
+    auto key_shape = key_states->shape();
+
+    Size batch_size = query_shape[0];
+    Size num_attention_heads = query_shape[1];
+    Size ntoken = query_shape[2];
+
+    Size num_key_value_heads = key_shape[1];
+    Size total_token = key_shape[2];
+    Size head_dim = key_shape[3];
+
+    assert(0 == (num_attention_heads % num_key_value_heads));
+    Size ngroup = num_attention_heads / num_key_value_heads;
+
+    float attention_scale{0.0f};
+    if (scale.has_value()) {
+        attention_scale = scale.value();
+    } else {
+        attention_scale = 1.f / float(sqrt(head_dim));
+    }
+
+    Tensor out_view = out->view({batch_size, num_key_value_heads, ngroup * ntoken, head_dim});
+    for (Size ib = 0; ib < batch_size; ++ib) {
+        Tensor q = query_states->narrow({{0, ib, 1}})->view({num_attention_heads, ntoken, head_dim});      // [num_attention_heads, ntoken, head_dim]
+        Tensor k = key_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim});   // [num_key_value_heads, total_token, head_dim]
+        Tensor v = value_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [num_key_value_heads, total_token, head_dim]
+        Tensor output_v = out_view->narrow({{0, ib, 1}})->view({num_key_value_heads, ngroup * ntoken, head_dim});
+        {
+            /*
+            Inputs:
+                q, [num_attention_heads, ntoken, head_dim]
+                k, [num_key_value_heads, total_token, head_dim]
+                v, [num_key_value_heads, total_token, head_dim]
+            Output:
+                att_val : {num_key_value_heads, ngroup * ntoken, head_dim}
+            */
+
+            auto q_gemm = q->view({num_key_value_heads, ngroup * ntoken, head_dim}); // => {nkvh, ngroup * ntoken, dh}
+            auto k_gemm = k->permute({0, 2, 1});                                     // => {nkvh, dh, total_token}
+            auto v_gemm = v;                                                         // => {nkvh, total_token, dh}
+
+            // qk_score : => {nkvh, ngroup * ntoken, total_token}
+            Tensor qk_score = gemm(q_gemm, // {nkvh, ngroup * ntoken, dh}
+                                   k_gemm, // {nkvh, dh, total_token}
+                                   attention_scale, 0.f);
+
+            // softmax
+            auto qk_softmax = qk_score->view({num_attention_heads, ntoken, total_token});
+            causal_softmax_(qk_softmax, qk_softmax);
+
+            // values
+            gemm_(output_v, // {nkvh, ngroup * ntoken, dh}
+                  qk_score, // {nkvh, ngroup * ntoken, total_token}
+                  v_gemm,   // {nkvh, total_token, dh}
+                  1.0f, 0.0f);
+        }
+    }
+}
 } // namespace infinicore::op
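
For cross-checking the loop above, here is a hedged PyTorch sketch of the same per-batch GQA computation (group query heads per KV head, scaled QK^T, causal softmax, then multiply by V). It assumes causal_softmax_ applies a bottom-right-aligned causal mask, which is consistent with the reference used in the test below; it is an illustration, not InfiniCore code.

import math
import torch

def gqa_reference(query, key, value, scale=None):
    # query: [bs, n_heads, ntoken, head_dim]; key/value: [bs, n_kv_heads, total_token, head_dim]
    bs, n_heads, ntoken, head_dim = query.shape
    n_kv_heads, total_token = key.shape[1], key.shape[2]
    assert n_heads % n_kv_heads == 0
    ngroup = n_heads // n_kv_heads
    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)

    out = torch.empty_like(query)
    out_view = out.view(bs, n_kv_heads, ngroup * ntoken, head_dim)
    # Bottom-right-aligned causal mask over (ntoken, total_token).
    mask = torch.tril(
        torch.ones(ntoken, total_token, dtype=torch.bool, device=query.device),
        diagonal=total_token - ntoken,
    )
    for ib in range(bs):
        q = query[ib].reshape(n_kv_heads, ngroup * ntoken, head_dim)  # group heads per KV head
        k, v = key[ib], value[ib]                  # [n_kv_heads, total_token, head_dim]
        score = scale * (q @ k.transpose(-1, -2))  # [n_kv_heads, ngroup * ntoken, total_token]
        score = score.view(n_heads, ntoken, total_token)
        score = score.masked_fill(~mask, float("-inf")).softmax(dim=-1)
        out_view[ib] = score.view(n_kv_heads, ngroup * ntoken, total_token) @ v
    return out

if __name__ == "__main__":
    q = torch.randn(2, 8, 5, 64)
    k = torch.randn(2, 2, 5, 64)
    v = torch.randn(2, 2, 5, 64)
    # enable_gqa requires a recent PyTorch release; the test below relies on it as well.
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
    print(torch.allclose(gqa_reference(q, k, v), ref, atol=1e-4))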

src/infinicore/pybind11/ops/attention.hpp

Lines changed: 41 additions & 1 deletion

@@ -8,6 +8,29 @@ namespace py = pybind11;

 namespace infinicore::ops {

+Tensor py_self_attention(Tensor query,
+                         Tensor key,
+                         Tensor value,
+                         pybind11::object scale) {
+    std::optional<float> scale_float = std::nullopt;
+    if (!scale.is_none()) {
+        scale_float = scale.cast<float>();
+    }
+    return op::self_attention(query, key, value, scale_float);
+}
+
+void py_self_attention_(Tensor out,
+                        Tensor query,
+                        Tensor key,
+                        Tensor value,
+                        pybind11::object scale) {
+    std::optional<float> scale_float = std::nullopt;
+    if (!scale.is_none()) {
+        scale_float = scale.cast<float>();
+    }
+    op::self_attention_(out, query, key, value, scale_float);
+}
+
 inline void bind_attention(py::module &m) {
     m.def("attention",
           &op::attention,
@@ -21,7 +44,7 @@ inline void bind_attention(py::module &m) {

         Args:
             q: Query tensor
-            k: Key tensor
+            k: Key tensor
             v: Value tensor
             k_cache: Key cache tensor
             v_cache: Value cache tensor
@@ -51,6 +74,23 @@ inline void bind_attention(py::module &m) {
             v_cache: Value cache tensor
             pos: Current position in the sequence
         )doc");
+
+    m.def("self_attention",
+          &ops::py_self_attention,
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(Computes scaled dot product attention on query, key and value tensors)doc");
+
+    m.def("self_attention_",
+          &ops::py_self_attention_,
+          py::arg("out"),
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(In-place: computes scaled dot product attention on query, key and value tensors)doc");
 }

 } // namespace infinicore::ops
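
The Python wrapper shown earlier reaches these bindings through infinicore.lib._infinicore. A minimal sketch of that call path, assuming tensors whose ._underlying handles already exist (their construction is not part of this commit):

from infinicore.lib import _infinicore

def call_raw_binding(query, key, value, scale=None):
    # scale=None crosses pybind11 as py::none() and becomes std::nullopt,
    # so the C++ side falls back to 1/sqrt(head_dim); a Python float is cast to float.
    return _infinicore.self_attention(
        query._underlying, key._underlying, value._underlying, scale
    )
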
Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.runner import GenericTestRunner
+from framework.utils import is_broadcast
+
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+_TEST_CASES_DATA = [
+    # bs, ntoken, total_token, num_attention_heads, num_key_value_heads, head_dim
+    (1, 4, 4, 8, 8, 64),
+    (1, 1, 4, 8, 8, 64),
+    (4, 16, 16, 32, 8, 64),
+    (4, 1, 128, 32, 8, 64),
+]
+
+
+# Tolerance configuration
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
+    infinicore.float32: {"atol": 1e-2, "rtol": 1e-2},
+    infinicore.bfloat16: {"atol": 5e-2, "rtol": 5e-2},
+}
+
+
+# Data types to test
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+# _TENSOR_DTYPES = [infinicore.bfloat16]
+
+
+def parse_test_cases():
+    """
+    Parse test case data and return list of TestCase objects for sdpa operation.
+    Each test case contains all necessary information for execution and validation.
+    """
+    test_cases = []
+
+    for data in _TEST_CASES_DATA:
+        bs = data[0]
+        ntoken, total_token = data[1], data[2]
+        num_attention_heads, num_key_value_heads = data[3], data[4]
+        head_dim = data[5]
+
+        # Determine shapes based on batch dimension
+        query_shape = (bs, num_attention_heads, ntoken, head_dim)
+        key_shape = (bs, num_key_value_heads, total_token, head_dim)
+        value_shape = (bs, num_key_value_heads, total_token, head_dim)
+        out_shape = (bs, num_attention_heads, ntoken, head_dim)
+
+        # Check if tensors support in-place operations
+        c_supports_inplace = not is_broadcast(out_shape)
+
+        # Generate test cases for all data types
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
+
+            # Create typed tensor specs
+            query_spec = TensorSpec.from_tensor(query_shape, None, dtype)
+            key_spec = TensorSpec.from_tensor(key_shape, None, dtype)
+            value_spec = TensorSpec.from_tensor(value_shape, None, dtype)
+            out_spec = TensorSpec.from_tensor(out_shape, None, dtype)
+
+            # Test Case 1: Out-of-place (return value)
+            test_cases.append(
+                TestCase(
+                    inputs=[query_spec, key_spec, value_spec],
+                    kwargs={},
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tolerance,
+                    description=f"sdpa - OUT_OF_PLACE",
+                )
+            )
+
+            # Test Case 2: In-place with explicit output tensor
+            if c_supports_inplace:
+                test_cases.append(
+                    TestCase(
+                        inputs=[query_spec, key_spec, value_spec],
+                        kwargs=None,
+                        output_spec=out_spec,  # Specify the output tensor spec
+                        comparison_target="out",
+                        tolerance=tolerance,
+                        description=f"sdpa - INPLACE(out)",
+                    )
+                )
+
+    return test_cases
+
+
+class OpTest(BaseOperatorTest):
+    """sdpa operator test with simplified implementation"""
+
+    def __init__(self):
+        super().__init__("sdpa")
+
+    def get_test_cases(self):
+        return parse_test_cases()
+
+    def torch_operator(self, query, key, value, out=None, **kwargs):
+        """PyTorch sdpa implementation"""
+        ntoken = query.shape[-2]
+        total_token = key.shape[-2]
+
+        is_causal = True
+        if 1 == ntoken and total_token > 1:
+            is_causal = False
+
+        result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=is_causal, enable_gqa=True
+        )
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+
+    def infinicore_operator(self, query, key, value, out=None, **kwargs):
+        """InfiniCore sdpa implementation"""
+        return infinicore.nn.functional.self_attention(query, key, value, out=out)
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()
