Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions RWKV-v7/ascend/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
cmake_minimum_required(VERSION 3.16.0)
project(Ascend_C)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# user-defined configuration
set(SOC_VERSION "Ascend910B3" CACHE STRING "system on chip type")
set(ASCEND_CANN_PACKAGE_PATH "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "ASCEND CANN package installation directory")
set(RUN_MODE "npu" CACHE STRING "run mode: npu")
set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Build type Release/Debug (default Debug)" FORCE)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/out" CACHE STRING "path for install()" FORCE)

if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
else()
message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the cann package is installed.")
endif()

include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)

# ascendc_library use to add kernel file to generate ascendc library
ascendc_library(kernels STATIC
wkv7s.cc
)

add_library(pybind11_lib SHARED wkv7s_op.cpp)
target_link_libraries(pybind11_lib PRIVATE
kernels
torch_npu
)
execute_process(COMMAND python3 -c "import os; import torch; print(os.path.dirname(torch.__file__))"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_PATH
)
message("TORCH_PATH is ${TORCH_PATH}")
set(ENV{ASCEND_HOME_PATH} ${ASCEND_CANN_PACKAGE_PATH})
execute_process(COMMAND python3 -c "import os; import torch_npu; print(os.path.dirname(torch_npu.__file__))"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_NPU_PATH
)
message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
target_link_directories(pybind11_lib PRIVATE
${TORCH_PATH}/lib
${TORCH_NPU_PATH}/lib
)
target_link_options(pybind11_lib PRIVATE
-Wl,-rpath,${TORCH_PATH}/lib
-Wl,-rpath,${TORCH_NPU_PATH}/lib
)
target_include_directories(pybind11_lib PRIVATE
${TORCH_NPU_PATH}/include
${TORCH_PATH}/include
${TORCH_PATH}/include/torch/csrc/api/include
)
execute_process(COMMAND python3 -m pybind11 --includes
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE PYBIND11_INC
)
string(REPLACE " " ";" PYBIND11_INC ${PYBIND11_INC})
target_compile_options(pybind11_lib PRIVATE
${PYBIND11_INC}
-D_GLIBCXX_USE_CXX11_ABI=0
)

execute_process(COMMAND python3-config --extension-suffix
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE PYBIND11_SUFFIX
)
set_target_properties(pybind11_lib PROPERTIES
OUTPUT_NAME wkv7s${PYBIND11_SUFFIX}
PREFIX "" SUFFIX ""
)
18 changes: 18 additions & 0 deletions RWKV-v7/ascend/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# 1. 编译算子
```
mkdir build
cd build
cmake ..
make
```

# 2. 测试算子
```
python test_rwkv.py
```

# 3. 运行模型
```
cd ..
python rwkv_v7_demo_fast_npu.py
```
101 changes: 101 additions & 0 deletions RWKV-v7/ascend/test_rwkv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import torch
import torch_npu
import sys
sys.path.append("./build")
import wkv7s

HEAD_SIZE = 64
DTYPE = torch.float16

# load(name="wkv7s", sources=["wkv7s_op.cpp", f"wkv7s.cu"], is_python_module=False,
# verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
class WKV_7(torch.autograd.Function):
@staticmethod
def forward(ctx, state, r, w, k, v, a, b):
with torch.no_grad():
B, T, C = r.size()
H = C // HEAD_SIZE
N = HEAD_SIZE
assert HEAD_SIZE == C // H
assert r.dtype == DTYPE
assert w.dtype == DTYPE
assert k.dtype == DTYPE
assert v.dtype == DTYPE
assert a.dtype == DTYPE
assert b.dtype == DTYPE
assert r.is_contiguous()
assert w.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert a.is_contiguous()
assert b.is_contiguous()
y = torch.empty((B, T, C), device=k.device, dtype=DTYPE, memory_format=torch.contiguous_format)
wkv7s.forward(B, T, C, H, state, r, w, k, v, a, b, y)
return y

def RWKV7_OP_KERNEL(state, r, w, k, v, a, b):
return WKV_7.apply(state, r, w, k, v, a, b)


def RWKV7_OP_TORCH(state, r, w, k, v, a, b):
B, T, C = r.size()
H = C // HEAD_SIZE
N = HEAD_SIZE
r = r.view(B, T, H, N).float()
k = k.view(B, T, H, N).float()
v = v.view(B, T, H, N).float()
a = a.view(B, T, H, N).float()
b = b.view(B, T, H, N).float()
w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)

for t in range(T):
kk = k[:, t, :].view(B, H, 1, N)
rr = r[:, t, :].view(B, H, N, 1)
vv = v[:, t, :].view(B, H, N, 1)
aa = a[:, t, :].view(B, H, N, 1)
bb = b[:, t, :].view(B, H, 1, N)
state = state * w[: , t, :, None, :] + state @ aa @ bb + vv @ kk
out[:, t, :] = (state @ rr).view(B, H, N)

return out.view(B, T, C).to(dtype=DTYPE), state


if __name__ == "__main__":
device = "npu"
B = 1
T = 1
C = 1024

torch.manual_seed(42)
torch.set_printoptions(precision=4, sci_mode=False)

r = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous()
w = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous()
k = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous()
v = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous()
a = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous()
b = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous()
state = torch.randn(B, C // HEAD_SIZE, HEAD_SIZE, HEAD_SIZE, dtype=torch.float, device=device)

with torch.no_grad():
y_torch, state_torch = RWKV7_OP_TORCH(state.clone(), r, w, k, v, a, b)
torch.npu.synchronize()
state_kernel = state.clone()
y_kernel = RWKV7_OP_KERNEL(state_kernel, r, w, k, v, a, b)

print(r[0][0][:64])
print(state[0][0][0][:64])

print(state_torch[0][0][0][:64])
print(state_kernel[0][0][0][:64])

print(y_torch[0][0][:64])
print(y_kernel[0][0][:64])


# === 比较结果 ===
abs_diff = (y_kernel - y_torch).abs().float()
max_diff = abs_diff.max().item()
print("Max absolute difference:", max_diff)
print("All close (atol=1e-3):", torch.allclose(y_kernel, y_torch, atol=1e-3, rtol=1e-3))
Loading