12 changes: 12 additions & 0 deletions include/infinicore/ops/attention.hpp
@@ -2,6 +2,7 @@

#include "../device.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {
class Attention {
@@ -13,4 +14,15 @@ class Attention {

Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);

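// Scaled dot-product attention over full query/key/value tensors
// (grouped-query layout: num_attention_heads must be a multiple of
// num_key_value_heads). An empty `scale` defaults to 1/sqrt(head_dim).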
Tensor self_attention(Tensor query,
Tensor key,
Tensor value,
std::optional<float> scale);

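// In-place variant: writes the result into `out`, shaped
// [batch_size, num_attention_heads, ntoken, head_dim].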
void self_attention_(Tensor out,
Tensor query,
Tensor key,
Tensor value,
std::optional<float> scale);
} // namespace infinicore::op
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
@@ -4,6 +4,7 @@
from .random_sample import random_sample
from .rms_norm import rms_norm
from .rope import RopeAlgo, rope
from .self_attention import self_attention
from .silu import silu
from .swiglu import swiglu

@@ -17,4 +18,5 @@
"embedding",
"rope",
"RopeAlgo",
"self_attention",
]
39 changes: 39 additions & 0 deletions python/infinicore/nn/functional/self_attention.py
@@ -0,0 +1,39 @@
from typing import Optional

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def self_attention(
query: Tensor,
key: Tensor,
value: Tensor,
scale: Optional[float] = None,
*,
out=None,
) -> Tensor:
r"""Computes scaled dot product attention on query, key and value tensors."""

seq_len = query.shape[-2]
total_seq_len = key.shape[-2]

assert (1 == seq_len and total_seq_len > 1) or (seq_len == total_seq_len), (
"Incorrect parameter value."
)

if out is None:
return Tensor(
_infinicore.self_attention(
query._underlying, key._underlying, value._underlying, scale
)
)

_infinicore.self_attention_(
out._underlying,
query._underlying,
key._underlying,
value._underlying,
scale,
)

return out
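
For reference, the behaviour exercised by the test in test/infinicore/ops/self_attention.py mirrors torch.nn.functional.scaled_dot_product_attention: causal masking during prefill (seq_len == total_seq_len), no mask for single-token decode, and enable_gqa=True when there are fewer key/value heads than query heads. A minimal torch-only sketch with made-up shapes (not part of this PR):

import math

import torch

bs, num_heads, ntoken, head_dim = 1, 8, 4, 64  # illustrative sizes
q = torch.randn(bs, num_heads, ntoken, head_dim)
k = torch.randn(bs, num_heads, ntoken, head_dim)
v = torch.randn(bs, num_heads, ntoken, head_dim)

# Prefill case: causal mask, scale of 1 / sqrt(head_dim) (torch's default).
ref = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, is_causal=True, scale=1.0 / math.sqrt(head_dim)
)
print(ref.shape)  # torch.Size([1, 8, 4, 64])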
88 changes: 87 additions & 1 deletion src/infinicore/ops/attention/attention.cc
@@ -1,5 +1,7 @@
#include "infinicore/ops/attention.hpp"

#include "infinicore/ops/causal_softmax.hpp"
#include "infinicore/ops/gemm.hpp"
#include <cassert>
#include <cmath>
namespace infinicore::op {

common::OpDispatcher<Attention::schema> &Attention::dispatcher() {
@@ -25,4 +27,88 @@ void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
Attention::execute(out, q, k, v, k_cache, v_cache, pos);
}

Tensor self_attention(Tensor query_states, // [bs, num_attention_heads, ntoken, head_dim]
Tensor key_states, // [bs, num_key_value_heads, total_token, head_dim]
Tensor value_states, // [bs, num_key_value_heads, total_token, head_dim]
std::optional<float> scale) {

auto query_shape = query_states->shape();
auto key_shape = key_states->shape();

Size batch_size = query_shape[0];
Size num_attention_heads = query_shape[1];
Size ntoken = query_shape[2];
Size head_dim = key_shape[3];

Tensor output_values = Tensor::empty({batch_size, num_attention_heads, ntoken, head_dim}, query_states->dtype(), query_states->device());

self_attention_(output_values, query_states, key_states, value_states, scale);

return output_values;
}

void self_attention_(Tensor out,
Tensor query_states,
Tensor key_states,
Tensor value_states,
std::optional<float> scale) {

auto query_shape = query_states->shape();
auto key_shape = key_states->shape();

Size batch_size = query_shape[0];
Size num_attention_heads = query_shape[1];
Size ntoken = query_shape[2];

Size num_key_value_heads = key_shape[1];
Size total_token = key_shape[2];
Size head_dim = key_shape[3];

assert(0 == (num_attention_heads % num_key_value_heads));
Size ngroup = num_attention_heads / num_key_value_heads;

const float attention_scale = scale.value_or(1.0f / std::sqrt(static_cast<float>(head_dim)));

Tensor out_view = out->view({batch_size, num_key_value_heads, ngroup * ntoken, head_dim});
for (Size ib = 0; ib < batch_size; ++ib) {
Tensor q = query_states->narrow({{0, ib, 1}})->view({num_attention_heads, ntoken, head_dim}); // [ num_attention_heads, ntoken, head_dim]
Tensor k = key_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [ num_key_value_heads, total_token, head_dim]
Tensor v = value_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [ num_key_value_heads, total_token, head_dim]
Tensor output_v = out_view->narrow({{0, ib, 1}})->view({num_key_value_heads, ngroup * ntoken, head_dim});
{
/*
Inputs:
    q, [num_attention_heads, ntoken, head_dim]
    k, [num_key_value_heads, total_token, head_dim]
    v, [num_key_value_heads, total_token, head_dim]
Output:
    att_val: {num_key_value_heads, ngroup * ntoken, head_dim}
*/

auto q_gemm = q->view({num_key_value_heads, ngroup * ntoken, head_dim}); // => {nkvh, ngroup * seq_len, dh}
auto k_gemm = k->permute({0, 2, 1}); // => { nkvh, dh, total_token}
auto v_gemm = v; // => { nkvh, total_token, dh}

// qk_score : => {nkvh, ngroup * ntoken, total_token}
Tensor qk_score = gemm(q_gemm, // {nkvh, ngroup * ntoken, dh}
k_gemm, // {nkvh, dh, total_token}
attention_scale, 0.f);

// causal softmax, applied in place per attention head via a view of qk_score
auto qk_softmax = qk_score->view({num_attention_heads, ntoken, total_token});
causal_softmax_(qk_softmax, qk_softmax);

// weighted sum of values (qk_score now holds the softmax weights through the qk_softmax view)
gemm_(output_v, // {nkvh, ngroup * ntoken, dh}
qk_score, // {nkvh, ngroup * ntoken, total_token}
v_gemm, // { nkvh, total_token, dh}
1.0f, 0.0f);
}
}
}
} // namespace infinicore::op
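
The per-batch loop above handles grouped-query attention without repeating key/value heads: the query block [num_attention_heads, ntoken, head_dim] is viewed as [num_key_value_heads, ngroup * ntoken, head_dim], so a single batched GEMM per key/value head covers all of its query heads. A small NumPy sketch of that equivalence (illustrative shapes, not library code):

import numpy as np

nkvh, ngroup, ntoken, total_token, dh = 2, 4, 3, 3, 8
nah = nkvh * ngroup  # num_attention_heads

q = np.random.rand(nah, ntoken, dh).astype(np.float32)
k = np.random.rand(nkvh, total_token, dh).astype(np.float32)

# Reference: query head h attends to key/value head h // ngroup.
ref = np.stack([q[h] @ k[h // ngroup].T for h in range(nah)])

# Grouped form: one batched matmul per key/value head over ngroup * ntoken rows.
scores = q.reshape(nkvh, ngroup * ntoken, dh) @ k.transpose(0, 2, 1)
scores = scores.reshape(nah, ntoken, total_token)  # back to per-head layout

assert np.allclose(ref, scores, atol=1e-5)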
42 changes: 41 additions & 1 deletion src/infinicore/pybind11/ops/attention.hpp
@@ -8,6 +8,29 @@ namespace py = pybind11;

namespace infinicore::ops {

Tensor py_self_attention(Tensor query,
Tensor key,
Tensor value,
pybind11::object scale) {
std::optional<float> scale_float = std::nullopt;
if (!scale.is_none()) {
scale_float = scale.cast<float>();
}
return op::self_attention(query, key, value, scale_float);
}

void py_self_attention_(Tensor out,
Tensor query,
Tensor key,
Tensor value,
pybind11::object scale) {
std::optional<float> scale_float = std::nullopt;
if (!scale.is_none()) {
scale_float = scale.cast<float>();
}
op::self_attention_(out, query, key, value, scale_float);
}

inline void bind_attention(py::module &m) {
m.def("attention",
&op::attention,
@@ -21,7 +44,7 @@ inline void bind_attention(py::module &m) {

Args:
q: Query tensor
k: Key tensor
k: Key tensor
v: Value tensor
k_cache: Key cache tensor
v_cache: Value cache tensor
@@ -51,6 +74,23 @@ inline void bind_attention(py::module &m) {
v_cache: Value cache tensor
pos: Current position in the sequence
)doc");

m.def("self_attention",
&ops::py_self_attention,
py::arg("query"),
py::arg("key"),
py::arg("value"),
py::arg("scale") = py::none(),
R"doc(Computes scaled dot product attention on query, key and value tensors)doc");

m.def("self_attention_",
&ops::py_self_attention_,
py::arg("out"),
py::arg("query"),
py::arg("key"),
py::arg("value"),
py::arg("scale") = py::none(),
R"doc(In-place, Computes scaled dot product attention on query, key and value tensors)doc");
}

} // namespace infinicore::ops
136 changes: 136 additions & 0 deletions test/infinicore/ops/self_attention.py
@@ -0,0 +1,136 @@
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast


# ==============================================================================
# Operator-specific configuration
# ==============================================================================
_TEST_CASES_DATA = [
# bs, ntoken, total_token, num_attention_heads, num_key_value_heads, head_dim
(1, 4, 4, 8, 8, 64),
(1, 1, 4, 8, 8, 64),
(4, 16, 16, 32, 8, 64),
(4, 1, 128, 32, 8, 64),
]


# Tolerance configuration
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-2, "rtol": 1e-2},
infinicore.bfloat16: {"atol": 5e-2, "rtol": 5e-2},
}


# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]


def parse_test_cases():
"""
Parse test case data and return a list of TestCase objects for the sdpa operation.
Each test case contains all necessary information for execution and validation.
"""
test_cases = []

for data in _TEST_CASES_DATA:
bs = data[0]
ntoken, total_token = data[1], data[2]
num_attention_heads, num_key_value_heads = data[3], data[4]
head_dim = data[5]

# Determine shapes based on batch dimension
query_shape = (bs, num_attention_heads, ntoken, head_dim)
key_shape = (bs, num_key_value_heads, total_token, head_dim)
value_shape = (bs, num_key_value_heads, total_token, head_dim)
out_shape = (bs, num_attention_heads, ntoken, head_dim)

# Check if tensors support in-place operations
c_supports_inplace = not is_broadcast(out_shape)

# Generate test cases for all data types
for dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

# Create typed tensor specs
query_spec = TensorSpec.from_tensor(query_shape, None, dtype)
key_spec = TensorSpec.from_tensor(key_shape, None, dtype)
value_spec = TensorSpec.from_tensor(value_shape, None, dtype)
out_spec = TensorSpec.from_tensor(out_shape, None, dtype)

# Test Case 1: Out-of-place (return value)
test_cases.append(
TestCase(
inputs=[query_spec, key_spec, value_spec],
kwargs={},
output_spec=None,
comparison_target=None,
tolerance=tolerance,
description=f"sdpa - OUT_OF_PLACE",
)
)

# Test Case 2: In-place with explicit output tensor
if c_supports_inplace:
test_cases.append(
TestCase(
inputs=[query_spec, key_spec, value_spec],
kwargs=None,
output_spec=out_spec, # Specify the output tensor spec
comparison_target="out",
tolerance=tolerance,
description=f"sdpa - INPLACE(out)",
)
)

return test_cases


class OpTest(BaseOperatorTest):
"""sdpa operator test with simplified implementation"""

def __init__(self):
super().__init__("sdpa")

def get_test_cases(self):
return parse_test_cases()

def torch_operator(self, query, key, value, out=None, **kwargs):
"""PyTorch sdpa implementation"""
ntoken = query.shape[-2]
total_token = key.shape[-2]

# A single-token decode step attends to the whole cached sequence, so no causal
# mask is applied; prefill (ntoken == total_token) uses a causal mask.
is_causal = not (ntoken == 1 and total_token > 1)

result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, is_causal=is_causal, enable_gqa=True
)
if out is not None:
out.copy_(result)
return out
return result

def infinicore_operator(self, query, key, value, out=None, **kwargs):
"""InfiniCore sdpa implementation"""
return infinicore.nn.functional.self_attention(query, key, value, out=out)


def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()


if __name__ == "__main__":
main()