diff --git a/Cargo.toml b/Cargo.toml index 64f18ae34..9e4fc0abc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } +half = {version = "2.7.1", default-features = false, features = ["num-traits"], optional = true} + [dev-dependencies] defmac = "0.2" quickcheck = { workspace = true } @@ -58,15 +60,17 @@ default = ["std"] # See README for more instructions blas = ["dep:cblas-sys", "dep:libc"] -serde = ["dep:serde"] +serde = ["dep:serde", "half?/serde"] -std = ["num-traits/std", "matrixmultiply/std"] +std = ["num-traits/std", "matrixmultiply/std", "half?/std"] rayon = ["dep:rayon", "std"] matrixmultiply-threading = ["matrixmultiply/threading"] portable-atomic-critical-section = ["portable-atomic/critical-section"] +half = ["dep:half"] + [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] portable-atomic = { version = "1.6.0" } @@ -115,6 +119,6 @@ tag-name = "{{version}}" # Config specific to docs.rs [package.metadata.docs.rs] -features = ["approx", "serde", "rayon"] +features = ["approx", "serde", "rayon", "half"] # Define the configuration attribute `docsrs` rustdoc-args = ["--cfg", "docsrs"] diff --git a/README.rst b/README.rst index 49558b1c1..5a9e1cdd3 100644 --- a/README.rst +++ b/README.rst @@ -101,6 +101,11 @@ your `Cargo.toml`. - Whether ``portable-atomic`` should use ``critical-section`` +- ``half`` + + - Enable support for the ``half::f16`` and ``half::bf16`` types. + + How to use with cargo --------------------- @@ -179,4 +184,3 @@ http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms. 
- diff --git a/benches/bench1.rs b/benches/bench1.rs index c07b8e3d9..9b4ae9e66 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -795,7 +795,7 @@ fn bench_col_iter(bench: &mut test::Bencher) } macro_rules! mat_mul { - ($modname:ident, $ty:ident, $(($name:ident, $m:expr, $n:expr, $k:expr))+) => { + ($modname:ident, $ty:ty, $(($name:ident, $m:expr, $n:expr, $k:expr))+) => { mod $modname { use test::{black_box, Bencher}; use ndarray::Array; @@ -814,6 +814,36 @@ macro_rules! mat_mul { }; } +#[cfg(feature = "half")] +mat_mul! {mat_mul_f16, half::f16, + (m004, 4, 4, 4) + (m007, 7, 7, 7) + (m008, 8, 8, 8) + (m012, 12, 12, 12) + (m016, 16, 16, 16) + (m032, 32, 32, 32) + (m064, 64, 64, 64) + (m127, 127, 127, 127) // ~128x slower than f32 + (mix16x4, 32, 4, 32) + (mix32x2, 32, 2, 32) + // (mix10000, 128, 10000, 128) // too slow +} + +#[cfg(feature = "half")] +mat_mul! {mat_mul_bf16, half::bf16, + (m004, 4, 4, 4) + (m007, 7, 7, 7) + (m008, 8, 8, 8) + (m012, 12, 12, 12) + (m016, 16, 16, 16) + (m032, 32, 32, 32) + (m064, 64, 64, 64) + (m127, 127, 127, 127) // 84x slower than f32 + (mix16x4, 32, 4, 32) + (mix32x2, 32, 2, 32) + // (mix10000, 128, 10000, 128) // too slow +} + mat_mul! 
{mat_mul_f32, f32, (m004, 4, 4, 4) (m007, 7, 7, 7) diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index f604ae091..399703bb4 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -218,6 +218,10 @@ fn mat_mut_zero_len() } } }); + #[cfg(feature = "half")] + mat_mul_zero_len!(range_mat::); + #[cfg(feature = "half")] + mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_i32); diff --git a/crates/numeric-tests/Cargo.toml b/crates/numeric-tests/Cargo.toml index 3e4014d25..d3549d7c8 100644 --- a/crates/numeric-tests/Cargo.toml +++ b/crates/numeric-tests/Cargo.toml @@ -21,12 +21,15 @@ rand_distr = { workspace = true } blas-src = { optional = true, version = "0.10", default-features = false, features = ["openblas"] } openblas-src = { optional = true, version = ">=0.10.11", default-features = false, features = ["cblas", "system"] } +half = { optional = true, version = "2.7.1", default-features = false, features = ["num-traits", "rand_distr"] } + [dev-dependencies] num-traits = { workspace = true } num-complex = { workspace = true } [features] test_blas = ["ndarray/blas", "blas-src", "openblas-src"] +half = ["dep:half", "ndarray/half"] # Config for cargo-release [package.metadata.release] diff --git a/crates/numeric-tests/tests/accuracy.rs b/crates/numeric-tests/tests/accuracy.rs index db10d57cd..f059e9a2a 100644 --- a/crates/numeric-tests/tests/accuracy.rs +++ b/crates/numeric-tests/tests/accuracy.rs @@ -140,6 +140,20 @@ fn accurate_eye_f64() } } +#[test] +#[cfg(feature = "half")] +fn accurate_mul_f16_dot() +{ + accurate_mul_float_general::(1e-2, false); +} + +#[test] +#[cfg(feature = "half")] +fn accurate_mul_bf16_dot() +{ + accurate_mul_float_general::(1e-1, false); +} + #[test] fn accurate_mul_f32_dot() { @@ -222,6 +236,20 @@ where } } +#[test] +#[cfg(feature = "half")] +fn accurate_mul_complex16() +{ + accurate_mul_complex_general::(1e-2); 
+} + +#[test] +#[cfg(feature = "half")] +fn accurate_mul_complexb16() +{ + accurate_mul_complex_general::(1e-1); +} + #[test] fn accurate_mul_complex32() { diff --git a/ndarray-rand/Cargo.toml b/ndarray-rand/Cargo.toml index 223dcdfc9..64e9d7451 100644 --- a/ndarray-rand/Cargo.toml +++ b/ndarray-rand/Cargo.toml @@ -20,10 +20,14 @@ rand = { workspace = true } rand_distr = { workspace = true } quickcheck = { workspace = true, optional = true } +half = { optional = true, version = "2.7.1", default-features = false, features = ["num-traits", "rand_distr"] } + [dev-dependencies] rand_isaac = "0.4.0" quickcheck = { workspace = true } +[features] +half = ["dep:half", "ndarray/half"] + [package.metadata.release] tag-name = "ndarray-rand-{{version}}" - diff --git a/ndarray-rand/benches/bench.rs b/ndarray-rand/benches/bench.rs index 364eca9f4..227d0dd85 100644 --- a/ndarray-rand/benches/bench.rs +++ b/ndarray-rand/benches/bench.rs @@ -16,6 +16,22 @@ fn uniform_f32(b: &mut Bencher) b.iter(|| Array::random((m, m), Uniform::new(-1f32, 1.).unwrap())); } +#[bench] +#[cfg(feature = "half")] +fn norm_f16(b: &mut Bencher) +{ + let m = 100; + b.iter(|| Array::random((m, m), Normal::new(half::f16::ZERO, half::f16::ONE).unwrap())); +} + +#[bench] +#[cfg(feature = "half")] +fn norm_bf16(b: &mut Bencher) +{ + let m = 100; + b.iter(|| Array::random((m, m), Normal::new(half::bf16::ZERO, half::bf16::ONE).unwrap())); +} + #[bench] fn norm_f32(b: &mut Bencher) { diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 53f49cc43..c347076e7 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -45,8 +45,16 @@ impl ScalarOperand for i128 {} impl ScalarOperand for u128 {} impl ScalarOperand for isize {} impl ScalarOperand for usize {} +#[cfg(feature = "half")] +impl ScalarOperand for half::f16 {} +#[cfg(feature = "half")] +impl ScalarOperand for half::bf16 {} impl ScalarOperand for f32 {} impl ScalarOperand for f64 {} +#[cfg(feature = "half")] +impl ScalarOperand for Complex {} +#[cfg(feature = "half")] +impl 
ScalarOperand for Complex {} impl ScalarOperand for Complex {} impl ScalarOperand for Complex {} @@ -468,6 +476,26 @@ mod arithmetic_ops impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or"); impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor"); + #[cfg(feature = "half")] + mod ops_f16 { + use super::*; + impl_scalar_lhs_op!(half::f16, Commute, +, Add, add, "addition"); + impl_scalar_lhs_op!(half::f16, Ordered, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(half::f16, Commute, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(half::f16, Ordered, /, Div, div, "division"); + impl_scalar_lhs_op!(half::f16, Ordered, %, Rem, rem, "remainder"); + } + + #[cfg(feature = "half")] + mod ops_bf16 { + use super::*; + impl_scalar_lhs_op!(half::bf16, Commute, +, Add, add, "addition"); + impl_scalar_lhs_op!(half::bf16, Ordered, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(half::bf16, Commute, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(half::bf16, Ordered, /, Div, div, "division"); + impl_scalar_lhs_op!(half::bf16, Ordered, %, Rem, rem, "remainder"); + } + impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition"); impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction"); impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication"); @@ -480,6 +508,24 @@ mod arithmetic_ops impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division"); impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder"); + #[cfg(feature = "half")] + mod ops_complex_f16 { + use super::*; + impl_scalar_lhs_op!(Complex, Commute, +, Add, add, "addition"); + impl_scalar_lhs_op!(Complex, Ordered, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(Complex, Commute, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(Complex, Ordered, /, Div, div, "division"); + } + + #[cfg(feature = "half")] + mod ops_complex_bf16 { + use super::*; + impl_scalar_lhs_op!(Complex, Commute, +, Add, add, "addition"); + impl_scalar_lhs_op!(Complex, Ordered, -, Sub, sub, 
"subtraction"); + impl_scalar_lhs_op!(Complex, Commute, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(Complex, Ordered, /, Div, div, "division"); + } + impl_scalar_lhs_op!(Complex, Commute, +, Add, add, "addition"); impl_scalar_lhs_op!(Complex, Ordered, -, Sub, sub, "subtraction"); impl_scalar_lhs_op!(Complex, Commute, *, Mul, mul, "multiplication"); diff --git a/tests/oper.rs b/tests/oper.rs index 0751c0c13..925785102 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -481,6 +481,10 @@ fn mat_mut_zero_len() } } }); + #[cfg(feature = "half")] + mat_mul_zero_len!(range_mat::); + #[cfg(feature = "half")] + mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_i32);