Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
rawpointer = { version = "0.2" }

# Optional half-precision float types; activated through the `half` feature.
# Default features are off and only `num-traits` integration is requested here;
# the `serde` and `std` features forward to it via `half?/serde` / `half?/std`.
half = {version = "2.7.1", default-features = false, features = ["num-traits"], optional = true}

[dev-dependencies]
defmac = "0.2"
quickcheck = { workspace = true }
Expand All @@ -58,15 +60,17 @@ default = ["std"]
# See README for more instructions
blas = ["dep:cblas-sys", "dep:libc"]

serde = ["dep:serde"]
serde = ["dep:serde", "half?/serde"]

std = ["num-traits/std", "matrixmultiply/std"]
std = ["num-traits/std", "matrixmultiply/std", "half?/std"]
rayon = ["dep:rayon", "std"]

matrixmultiply-threading = ["matrixmultiply/threading"]

portable-atomic-critical-section = ["portable-atomic/critical-section"]

# Turns on the optional `half` dependency (gates the `#[cfg(feature = "half")]` code).
half = ["dep:half"]


[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic = { version = "1.6.0" }
Expand Down Expand Up @@ -115,6 +119,6 @@ tag-name = "{{version}}"

# Config specific to docs.rs
[package.metadata.docs.rs]
features = ["approx", "serde", "rayon"]
features = ["approx", "serde", "rayon", "half"]
# Define the configuration attribute `docsrs`
rustdoc-args = ["--cfg", "docsrs"]
6 changes: 5 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ your `Cargo.toml`.

- Whether ``portable-atomic`` should use ``critical-section``

- ``half``

- Enable support for the ``half::f16`` and ``half::bf16`` types.


How to use with cargo
---------------------

Expand Down Expand Up @@ -179,4 +184,3 @@ http://www.apache.org/licenses/LICENSE-2.0 or the MIT license
http://opensource.org/licenses/MIT, at your
option. This file may not be copied, modified, or distributed
except according to those terms.

32 changes: 31 additions & 1 deletion benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ fn bench_col_iter(bench: &mut test::Bencher)
}

macro_rules! mat_mul {
($modname:ident, $ty:ident, $(($name:ident, $m:expr, $n:expr, $k:expr))+) => {
($modname:ident, $ty:ty, $(($name:ident, $m:expr, $n:expr, $k:expr))+) => {
mod $modname {
use test::{black_box, Bencher};
use ndarray::Array;
Expand All @@ -814,6 +814,36 @@ macro_rules! mat_mul {
};
}

// Matrix-multiplication benchmarks for `half::f16`, compiled only when the
// `half` feature is enabled. `mat_mul!` places them in module `mat_mul_f16`;
// each tuple is `(bench name, m, n, k)` per the macro's pattern.
// NOTE(review): m/n/k are presumably the matrix dimensions; the macro body is
// not fully visible here, so confirm.
#[cfg(feature = "half")]
mat_mul! {mat_mul_f16, half::f16,
(m004, 4, 4, 4)
(m007, 7, 7, 7)
(m008, 8, 8, 8)
(m012, 12, 12, 12)
(m016, 16, 16, 16)
(m032, 32, 32, 32)
(m064, 64, 64, 64)
(m127, 127, 127, 127) // ~128x slower than f32
(mix16x4, 32, 4, 32)
(mix32x2, 32, 2, 32)
// (mix10000, 128, 10000, 128) // too slow
}

// Matrix-multiplication benchmarks for `half::bf16`, compiled only when the
// `half` feature is enabled. `mat_mul!` places them in module `mat_mul_bf16`;
// each tuple is `(bench name, m, n, k)` per the macro's pattern.
// NOTE(review): m/n/k are presumably the matrix dimensions; the macro body is
// not fully visible here, so confirm.
#[cfg(feature = "half")]
mat_mul! {mat_mul_bf16, half::bf16,
(m004, 4, 4, 4)
(m007, 7, 7, 7)
(m008, 8, 8, 8)
(m012, 12, 12, 12)
(m016, 16, 16, 16)
(m032, 32, 32, 32)
(m064, 64, 64, 64)
(m127, 127, 127, 127) // 84x slower than f32
(mix16x4, 32, 4, 32)
(mix32x2, 32, 2, 32)
// (mix10000, 128, 10000, 128) // too slow
}

mat_mul! {mat_mul_f32, f32,
(m004, 4, 4, 4)
(m007, 7, 7, 7)
Expand Down
4 changes: 4 additions & 0 deletions crates/blas-tests/tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ fn mat_mut_zero_len()
}
}
});
#[cfg(feature = "half")]
mat_mul_zero_len!(range_mat::<half::f16>);
#[cfg(feature = "half")]
mat_mul_zero_len!(range_mat::<half::bf16>);
mat_mul_zero_len!(range_mat::<f32>);
mat_mul_zero_len!(range_mat::<f64>);
mat_mul_zero_len!(range_i32);
Expand Down
3 changes: 3 additions & 0 deletions crates/numeric-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ rand_distr = { workspace = true }
blas-src = { optional = true, version = "0.10", default-features = false, features = ["openblas"] }
openblas-src = { optional = true, version = ">=0.10.11", default-features = false, features = ["cblas", "system"] }

# Optional half-precision floats for this crate's `half`-gated tests.
# Unlike the other manifests in this change, `rand_distr` is enabled here
# in addition to `num-traits`.
half = { optional = true, version = "2.7.1", default-features = false, features = ["num-traits", "rand_distr"] }

[dev-dependencies]
num-traits = { workspace = true }
num-complex = { workspace = true }

[features]
test_blas = ["ndarray/blas", "blas-src", "openblas-src"]
# Enables the optional `half` dependency and forwards the feature to `ndarray`.
half = ["dep:half", "ndarray/half"]

# Config for cargo-release
[package.metadata.release]
Expand Down
28 changes: 28 additions & 0 deletions crates/numeric-tests/tests/accuracy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,20 @@ fn accurate_eye_f64()
}
}

/// Runs `accurate_mul_float_general` with `half::f16` elements, passing `1e-2`
/// and `false`. Only built when the `half` feature is enabled.
// NOTE(review): `1e-2` is presumably the accepted error tolerance and `false`
// a mode flag of the helper; confirm against its definition.
#[test]
#[cfg(feature = "half")]
fn accurate_mul_f16_dot()
{
accurate_mul_float_general::<half::f16>(1e-2, false);
}

/// Runs `accurate_mul_float_general` with `half::bf16` elements, passing `1e-1`
/// and `false`. Only built when the `half` feature is enabled.
// NOTE(review): `1e-1` is presumably the accepted error tolerance (larger than
// the `1e-2` used for `f16`) and `false` a mode flag of the helper; confirm
// against its definition.
#[test]
#[cfg(feature = "half")]
fn accurate_mul_bf16_dot()
{
accurate_mul_float_general::<half::bf16>(1e-1, false);
}

#[test]
fn accurate_mul_f32_dot()
{
Expand Down Expand Up @@ -222,6 +236,20 @@ where
}
}

/// Runs `accurate_mul_complex_general` with `half::f16` as the type parameter,
/// passing `1e-2`. Only built when the `half` feature is enabled.
// NOTE(review): `1e-2` is presumably the accepted error tolerance; confirm
// against the helper's definition.
#[test]
#[cfg(feature = "half")]
fn accurate_mul_complex16()
{
accurate_mul_complex_general::<half::f16>(1e-2);
}

/// Runs `accurate_mul_complex_general` with `half::bf16` as the type parameter,
/// passing `1e-1`. Only built when the `half` feature is enabled.
// NOTE(review): `1e-1` is presumably the accepted error tolerance; confirm
// against the helper's definition.
#[test]
#[cfg(feature = "half")]
fn accurate_mul_complexb16()
{
accurate_mul_complex_general::<half::bf16>(1e-1);
}

#[test]
fn accurate_mul_complex32()
{
Expand Down
6 changes: 5 additions & 1 deletion ndarray-rand/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ rand = { workspace = true }
rand_distr = { workspace = true }
quickcheck = { workspace = true, optional = true }

# Optional half-precision float types; activated through this crate's `half` feature.
# NOTE(review): the `norm_f16` / `norm_bf16` benches sample `Normal<half::f16>` /
# `Normal<half::bf16>`, which presumably requires `half`'s `rand_distr` feature
# (crates/numeric-tests enables it). Only `num-traits` is requested here; confirm
# this builds with `--features half` outside a feature-unified workspace build.
half = { optional = true, version = "2.7.1", default-features = false, features = ["num-traits"] }

[dev-dependencies]
rand_isaac = "0.4.0"
quickcheck = { workspace = true }

[features]
# Enables the optional `half` dependency and forwards the feature to `ndarray`.
half = ["dep:half", "ndarray/half"]

[package.metadata.release]
tag-name = "ndarray-rand-{{version}}"

16 changes: 16 additions & 0 deletions ndarray-rand/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ fn uniform_f32(b: &mut Bencher)
b.iter(|| Array::random((m, m), Uniform::new(-1f32, 1.).unwrap()));
}

/// Benchmarks filling a 100x100 array via `Array::random` from a `Normal`
/// distribution constructed with `half::f16::ZERO` and `half::f16::ONE`.
/// Only built when the `half` feature is enabled.
#[bench]
#[cfg(feature = "half")]
fn norm_f16(b: &mut Bencher)
{
let m = 100;
// The distribution is rebuilt inside the closure, so its construction is
// part of every timed iteration.
b.iter(|| Array::random((m, m), Normal::new(half::f16::ZERO, half::f16::ONE).unwrap()));
}

/// Benchmarks filling a 100x100 array via `Array::random` from a `Normal`
/// distribution constructed with `half::bf16::ZERO` and `half::bf16::ONE`.
/// Only built when the `half` feature is enabled.
#[bench]
#[cfg(feature = "half")]
fn norm_bf16(b: &mut Bencher)
{
let m = 100;
// The distribution is rebuilt inside the closure, so its construction is
// part of every timed iteration.
b.iter(|| Array::random((m, m), Normal::new(half::bf16::ZERO, half::bf16::ONE).unwrap()));
}

#[bench]
fn norm_f32(b: &mut Bencher)
{
Expand Down
46 changes: 46 additions & 0 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,16 @@ impl ScalarOperand for i128 {}
impl ScalarOperand for u128 {}
impl ScalarOperand for isize {}
impl ScalarOperand for usize {}
// Half-precision floats opt into `ScalarOperand` (empty impl bodies), but only
// when the `half` feature is enabled.
#[cfg(feature = "half")]
impl ScalarOperand for half::f16 {}
#[cfg(feature = "half")]
impl ScalarOperand for half::bf16 {}
impl ScalarOperand for f32 {}
impl ScalarOperand for f64 {}
// Complex numbers over the half-precision floats opt into `ScalarOperand`
// (empty impl bodies), but only when the `half` feature is enabled.
#[cfg(feature = "half")]
impl ScalarOperand for Complex<half::f16> {}
#[cfg(feature = "half")]
impl ScalarOperand for Complex<half::bf16> {}
impl ScalarOperand for Complex<f32> {}
impl ScalarOperand for Complex<f64> {}

Expand Down Expand Up @@ -468,6 +476,26 @@ mod arithmetic_ops
impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");

// `impl_scalar_lhs_op!` invocations for `half::f16` covering +, -, *, / and %.
// The single `#[cfg(feature = "half")]` on the module gates every invocation
// inside it; `use super::*` brings the parent module's items into scope.
// NOTE(review): presumably these generate `scalar <op> array` impls with the
// scalar on the left, and `Commute`/`Ordered` say whether operands may be
// swapped; the macro is not visible here, so confirm.
#[cfg(feature = "half")]
mod ops_f16 {
use super::*;
impl_scalar_lhs_op!(half::f16, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(half::f16, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(half::f16, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(half::f16, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(half::f16, Ordered, %, Rem, rem, "remainder");
}

// `impl_scalar_lhs_op!` invocations for `half::bf16` covering +, -, *, / and %.
// The single `#[cfg(feature = "half")]` on the module gates every invocation
// inside it; `use super::*` brings the parent module's items into scope.
// NOTE(review): presumably these generate `scalar <op> array` impls with the
// scalar on the left, and `Commute`/`Ordered` say whether operands may be
// swapped; the macro is not visible here, so confirm.
#[cfg(feature = "half")]
mod ops_bf16 {
use super::*;
impl_scalar_lhs_op!(half::bf16, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(half::bf16, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(half::bf16, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(half::bf16, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(half::bf16, Ordered, %, Rem, rem, "remainder");
}

impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
Expand All @@ -480,6 +508,24 @@ mod arithmetic_ops
impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");

// `impl_scalar_lhs_op!` invocations for `Complex<half::f16>` covering +, -, *
// and /; unlike the real-valued `f16` group there is no `%` invocation.
// The single `#[cfg(feature = "half")]` on the module gates every invocation
// inside it; `use super::*` brings the parent module's items into scope.
// NOTE(review): presumably these generate `scalar <op> array` impls with the
// scalar on the left; the macro is not visible here, so confirm.
#[cfg(feature = "half")]
mod ops_complex_f16 {
use super::*;
impl_scalar_lhs_op!(Complex<half::f16>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<half::f16>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<half::f16>, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<half::f16>, Ordered, /, Div, div, "division");
}

// `impl_scalar_lhs_op!` invocations for `Complex<half::bf16>` covering +, -, *
// and /; unlike the real-valued `bf16` group there is no `%` invocation.
// The single `#[cfg(feature = "half")]` on the module gates every invocation
// inside it; `use super::*` brings the parent module's items into scope.
// NOTE(review): presumably these generate `scalar <op> array` impls with the
// scalar on the left; the macro is not visible here, so confirm.
#[cfg(feature = "half")]
mod ops_complex_bf16 {
use super::*;
impl_scalar_lhs_op!(Complex<half::bf16>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<half::bf16>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<half::bf16>, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<half::bf16>, Ordered, /, Div, div, "division");
}

impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
Expand Down
4 changes: 4 additions & 0 deletions tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,10 @@ fn mat_mut_zero_len()
}
}
});
#[cfg(feature = "half")]
mat_mul_zero_len!(range_mat::<half::f16>);
#[cfg(feature = "half")]
mat_mul_zero_len!(range_mat::<half::bf16>);
mat_mul_zero_len!(range_mat::<f32>);
mat_mul_zero_len!(range_mat::<f64>);
mat_mul_zero_len!(range_i32);
Expand Down