Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions include/infinicore/nn/rope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ namespace infinicore::nn {
class RoPE : public Module {
public:
/**
* @brief RoPE algorithm type
* @brief RoPE rotation algorithm type (kernel-level and frequency generation)
*/
enum class Algo {
GPT_J = 0, // GPT-J style RoPE algorithm (Interleave even and odd dimensions)
GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
GPT_J = 0, // GPT-J style: pairs dimensions as (2j, 2j+1), frequency for cache entry j is theta^(-2j/head_dim)
GPT_NEOX = 1, // GPT-NeoX style: pairs dimensions as (j, j+head_dim/2), frequency for cache entry j is theta^(-j/head_dim)
};

/**
Expand All @@ -23,13 +23,15 @@ class RoPE : public Module {
* @param head_dim Dimension of each attention head (must be even)
* @param max_seq_len Maximum sequence length for pre-computed cache
* @param theta Base frequency for rotary embeddings (default: 10000.0)
* @param algo RoPE algorithm type (default: Algo::GPT_J)
* @param freq_gen Frequency generation method (default: Algo::GPT_J)
* @param algo Rotation algorithm type for kernel (default: Algo::GPT_J)
* @param dtype Data type for sin/cos cache (default: DataType::F32)
* @param device Device to create the cache on
*/
RoPE(size_t head_dim,
size_t max_seq_len,
double theta = 10000.0,
Algo freq_gen = Algo::GPT_J,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所以这个地方又是为什么要两个呢?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rope如果有问题的话需要解决一下

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

问了gpt,hf确实是这样算的

Algo algo = Algo::GPT_J,
const DataType &dtype = DataType::F32,
const Device &device = Device());
Expand All @@ -55,6 +57,7 @@ class RoPE : public Module {
size_t head_dim() const { return head_dim_; }
size_t max_seq_len() const { return max_seq_len_; }
double theta() const { return theta_; }
Algo freq_gen() const { return freq_gen_; }
Algo algo() const { return algo_; }
DataType dtype() const { return dtype_; }

Expand All @@ -72,7 +75,8 @@ class RoPE : public Module {
size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
Algo freq_gen_; // Frequency generation method
Algo algo_; // Rotation algorithm type (kernel-level)
DataType dtype_; // Data type for cache tables
};

Expand Down
2 changes: 2 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib

import infinicore.context as context
import infinicore.nn as nn

# Import context functions
Expand Down Expand Up @@ -60,6 +61,7 @@

__all__ = [
# Modules.
"context",
"nn",
# Classes.
"device",
Expand Down
4 changes: 2 additions & 2 deletions python/infinicore/nn/functional/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
class RopeAlgo:
r"""Different types of RoPE algorithms."""

GPT_J = _infinicore.Algo.GPT_J
GPT_NEOX = _infinicore.Algo.GPT_NEOX
GPT_J = _infinicore.RoPEAlgo.GPT_J
GPT_NEOX = _infinicore.RoPEAlgo.GPT_NEOX


def rope(
Expand Down
51 changes: 51 additions & 0 deletions python/infinicore/nn/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import Any

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor as TensorWrapper

__all__ = ["RoPE"]


def _unwrap_tensor(tensor: Any):
if isinstance(tensor, TensorWrapper):
return tensor._underlying
return tensor


def _wrap_tensor(tensor: Any) -> TensorWrapper:
if isinstance(tensor, TensorWrapper):
return tensor
return TensorWrapper(tensor)


class RoPE:
"""Python-friendly wrapper for ``_infinicore.RoPE``."""

def __init__(
self,
head_dim: int,
max_seq_len: int,
theta: float = 10000.0,
freq_gen: _infinicore.RoPEAlgo = _infinicore.RoPEAlgo.GPT_J,
algo: _infinicore.RoPEAlgo = _infinicore.RoPEAlgo.GPT_J,
dtype: _infinicore.DataType = _infinicore.DataType.F32,
device=None,
) -> None:
self._module = _infinicore.RoPE(
head_dim,
max_seq_len,
theta,
freq_gen,
algo,
dtype,
getattr(device, "_underlying", device),
)

def forward(self, x, pos):
output = self._module.forward(_unwrap_tensor(x), _unwrap_tensor(pos))
return _wrap_tensor(output)

def __call__(self, x, pos):
return self.forward(x, pos)
42 changes: 30 additions & 12 deletions src/infinicore-test/main.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include "memory_test.h"
#include "tensor/test_d2h_issue.h"
#include "tensor/test_tensor_destructor.h"
#include "tensor/test_tensor_usage.h"
#include "test_nn_module.h"
#include "test_runner.h"
#include "test_tensor_destructor.h"
#include <iostream>
#include <memory>
#include <spdlog/spdlog.h>
Expand All @@ -10,6 +12,7 @@
struct ParsedArgs {
infiniDevice_t device_type = INFINI_DEVICE_CPU;
bool run_basic = true;
bool run_tensor = true;
bool run_concurrency = true;
bool run_exception_safety = true;
bool run_memory_leak = true;
Expand All @@ -26,7 +29,7 @@ void printUsage() {
<< std::endl
<< "Options:" << std::endl
<< " --<device> Specify the device type (default: cpu)" << std::endl
<< " --test <name> Run specific test (basic|concurrency|exception|leak|performance|stress|module|all)" << std::endl
<< " --test <name> Run specific test (basic|tensor|concurrency|exception|leak|performance|stress|module|all)" << std::endl
<< " --threads <num> Number of threads for concurrency tests (default: 4)" << std::endl
<< " --iterations <num> Number of iterations for stress tests (default: 1000)" << std::endl
<< " --help Show this help message" << std::endl
Expand All @@ -45,6 +48,7 @@ void printUsage() {
<< std::endl
<< "Available tests:" << std::endl
<< " basic - Basic memory allocation and deallocation tests" << std::endl
<< " tensor - Tensor-related tests (destructor, usage)" << std::endl
<< " concurrency - Thread safety and concurrent access tests" << std::endl
<< " exception - Exception safety tests" << std::endl
<< " leak - Memory leak detection tests" << std::endl
Expand Down Expand Up @@ -91,10 +95,12 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
}

std::string test_name = argv[++i];
args.run_basic = args.run_concurrency = args.run_exception_safety = args.run_memory_leak = args.run_performance = args.run_stress = args.run_module = false;
args.run_basic = args.run_tensor = args.run_concurrency = args.run_exception_safety = args.run_memory_leak = args.run_performance = args.run_stress = args.run_module = false;

if (test_name == "basic") {
args.run_basic = true;
} else if (test_name == "tensor") {
args.run_tensor = true;
} else if (test_name == "concurrency") {
args.run_concurrency = true;
} else if (test_name == "exception") {
Expand All @@ -108,7 +114,7 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
} else if (test_name == "module") {
args.run_module = true;
} else if (test_name == "all") {
args.run_basic = args.run_concurrency = args.run_exception_safety = args.run_memory_leak = args.run_performance = args.run_stress = args.run_module = true;
args.run_basic = args.run_tensor = args.run_concurrency = args.run_exception_safety = args.run_memory_leak = args.run_performance = args.run_module = true;
} else {
std::cerr << "Error: Unknown test name: " << test_name << std::endl;
exit(EXIT_FAILURE);
Expand Down Expand Up @@ -145,7 +151,7 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
int main(int argc, char *argv[]) {
try {
ParsedArgs args = parseArgs(argc, argv);
spdlog::info("Arguments parsed successfully");
SPDLOG_INFO("Arguments parsed successfully");

std::cout << "==============================================\n"
<< "InfiniCore Memory Management Test Suite\n"
Expand All @@ -155,21 +161,32 @@ int main(int argc, char *argv[]) {
<< "Iterations: " << args.iterations << "\n"
<< "==============================================" << std::endl;

spdlog::info("About to initialize InfiniCore context");
SPDLOG_INFO("About to initialize InfiniCore context");
// Initialize InfiniCore context
infinicore::context::setDevice(infinicore::Device(static_cast<infinicore::Device::Type>(args.device_type), 0));
spdlog::info("InfiniCore context initialized successfully");
SPDLOG_INFO("InfiniCore context initialized successfully");

spdlog::info("Creating test runner");
SPDLOG_INFO("Creating test runner");
// Create test runner
infinicore::test::InfiniCoreTestRunner runner;
spdlog::info("Test runner created successfully");
SPDLOG_INFO("Test runner created successfully");

// Add tests based on arguments
if (args.run_basic) {
runner.addTest(std::make_unique<infinicore::test::BasicMemoryTest>());

// Add device switch test to basic tests (critical regression test)
runner.addTest(std::make_unique<infinicore::test::DeviceSwitchTest>());
}

if (args.run_tensor) {
runner.addTest(std::make_unique<infinicore::test::TensorDestructorTest>());

// Add tensor usage test (tests operations used in InfiniLM weight loading)
runner.addTest(std::make_unique<infinicore::test::TensorUsageTest>());

// Add D2H issue test (reproduces the segfault issue)
runner.addTest(std::make_unique<infinicore::test::D2HIssueTest>());
}

if (args.run_module) {
Expand All @@ -189,17 +206,18 @@ int main(int argc, char *argv[]) {
}

if (args.run_performance) {
runner.addTest(std::make_unique<infinicore::test::PerformanceTest>());
// TODO: Segmentation fault when LOG LEVEL is set to INFO, passed when set to DEBUG
// runner.addTest(std::make_unique<infinicore::test::PerformanceTest>());
}

if (args.run_stress) {
runner.addTest(std::make_unique<infinicore::test::StressTest>());
}

spdlog::info("About to run all tests");
SPDLOG_INFO("About to run all tests");
// Run all tests
auto results = runner.runAllTests();
spdlog::info("All tests completed");
SPDLOG_INFO("All tests completed");

// Count results and collect failed tests
size_t passed = 0, failed = 0;
Expand Down
Loading
Loading