Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/core_simd/examples/dot_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// Add these imports to use the stdsimd library
#![feature(portable_simd)]
use core_simd::simd::prelude::*;
use std_float::StdFloat;

// This is your barebones dot product implementation:
// Take 2 vectors, multiply them element wise and *then*
Expand Down Expand Up @@ -71,7 +72,6 @@ pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {

// A lot of knowledgeable use of SIMD comes from knowing specific instructions that are
// available - let's try to use the `mul_add` instruction, which is the fused-multiply-add we were looking for.
use std_float::StdFloat;
pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
// TODO handle remainder when a.len() % 4 != 0
Expand Down
6 changes: 3 additions & 3 deletions crates/core_simd/examples/spectral_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn a(i: usize, j: usize) -> f64 {

fn mult_av(v: &[f64], out: &mut [f64]) {
assert!(v.len() == out.len());
assert!(v.len() % 2 == 0);
assert!(v.len().is_multiple_of(2));

for (i, out) in out.iter_mut().enumerate() {
let mut sum = f64x2::splat(0.0);
Expand All @@ -26,7 +26,7 @@ fn mult_av(v: &[f64], out: &mut [f64]) {

fn mult_atv(v: &[f64], out: &mut [f64]) {
assert!(v.len() == out.len());
assert!(v.len() % 2 == 0);
assert!(v.len().is_multiple_of(2));

for (i, out) in out.iter_mut().enumerate() {
let mut sum = f64x2::splat(0.0);
Expand All @@ -48,7 +48,7 @@ fn mult_atav(v: &[f64], out: &mut [f64], tmp: &mut [f64]) {
}

pub fn spectral_norm(n: usize) -> f64 {
assert!(n % 2 == 0, "only even lengths are accepted");
assert!(n.is_multiple_of(2), "only even lengths are accepted");

let mut u = vec![1.0; n];
let mut v = u.clone();
Expand Down
54 changes: 54 additions & 0 deletions crates/std_float/examples/fma.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//! Demonstrates fused multiply-add (FMA) operations.

#![feature(portable_simd)]
use core_simd::simd::prelude::*;
use std_float::StdFloat;

fn main() {
    // Per-lane fused multiply-add: a*b + c rounds only once per lane.
    let a = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let b = f32x4::from_array([2.0, 3.0, 4.0, 5.0]);
    let c = f32x4::splat(10.0);

    println!("FMA: a*b + c");
    println!("a = {:?}", a.to_array());
    println!("b = {:?}", b.to_array());
    println!("c = {:?}", c.to_array());
    println!("result = {:?}", a.mul_add(b, c).to_array());
    println!();

    // Evaluate p(x) = 2x³ + 3x² + 4x + 5 in Horner form — one FMA per
    // remaining coefficient: ((2x + 3)x + 4)x + 5.
    let x = f32x8::from_array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    let mut result = f32x8::splat(2.0);
    for coeff in [3.0, 4.0, 5.0] {
        result = result.mul_add(x, f32x8::splat(coeff));
    }

    println!("Polynomial p(x) = 2x³ + 3x² + 4x + 5");
    println!("x = {:?}", x.to_array());
    println!("p(x) = {:?}", result.to_array());
    println!();

    // Scalar f32::mul_add works the same way; fold a dot product with it.
    let v1 = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let v2 = f32x4::from_array([5.0, 6.0, 7.0, 8.0]);
    let acc = v1
        .to_array()
        .iter()
        .zip(v2.to_array())
        .fold(0.0_f32, |sum, (&p, q)| p.mul_add(q, sum));

    println!("Dot product using FMA:");
    println!("v1 · v2 = {}", acc);
    println!();

    // Both forms round to the same f32 here: 1.0 is far below 1e10's f32 ulp,
    // so the fused and unfused results coincide.
    let large = f32x4::splat(1e10);
    let small = f32x4::splat(1.0);

    let fused = large.mul_add(f32x4::splat(1.0), small);
    let unfused = large * f32x4::splat(1.0) + small;

    println!("Accuracy comparison (1e10 * 1.0 + 1.0):");
    println!("FMA result: {:?}", fused.to_array());
    println!("Separate ops: {:?}", unfused.to_array());
    println!("Both preserve precision in this case");
}
13 changes: 13 additions & 0 deletions crates/std_float/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ pub trait StdFloat: Sealed + Sized {
unsafe { intrinsics::simd_fma(self, a, b) }
}

/// Elementwise fused multiply-subtract. Computes `(self * a) - b` with only one rounding error,
/// yielding a more accurate result than an unfused multiply-subtract.
///
/// Using `mul_sub` *may* be more performant than an unfused multiply-subtract if the target
/// architecture has a dedicated `fma` CPU instruction. However, this is not always
/// true, and will be heavily dependent on designing algorithms with specific target
/// hardware in mind.
#[inline]
#[must_use = "method returns a new vector and does not mutate the original value"]
fn mul_sub(self, a: Self, b: Self) -> Self {
    // SAFETY: same justification as `mul_add` above — the `Sealed` supertrait
    // restricts `Self` to the SIMD float vector types this crate implements,
    // which are valid operands for `simd_neg` and `simd_fma`.
    // NOTE(review): negating `b` and reusing `simd_fma` keeps the single
    // rounding step; confirm sealed impls cover exactly the float vectors.
    unsafe { intrinsics::simd_fma(self, a, intrinsics::simd_neg(b)) }
}

/// Produces a vector where every element has the square root value
/// of the equivalently-indexed element in `self`
#[inline]
Expand Down
176 changes: 176 additions & 0 deletions crates/std_float/tests/fma.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#![feature(portable_simd)]

use core_simd::simd::prelude::*;
use std_float::StdFloat;

#[test]
fn test_mul_add_basic() {
    // (x * y) + z per lane on f32x4.
    let x = f32x4::from_array([2.0, 3.0, 4.0, 5.0]);
    let y = f32x4::splat(10.0);
    let z = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let expected = f32x4::from_array([21.0, 32.0, 43.0, 54.0]);
    assert_eq!(x.mul_add(y, z), expected);
}

#[test]
fn test_mul_add_f64() {
    // Same contract as the f32 case, on 64-bit lanes.
    let x = f64x4::from_array([2.0, 3.0, 4.0, 5.0]);
    let y = f64x4::splat(10.0);
    let z = f64x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let expected = f64x4::from_array([21.0, 32.0, 43.0, 54.0]);
    assert_eq!(x.mul_add(y, z), expected);
}

#[test]
fn test_mul_sub_basic() {
    // (x * y) - z per lane on f32x4.
    let x = f32x4::from_array([2.0, 3.0, 4.0, 5.0]);
    let y = f32x4::splat(10.0);
    let z = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let expected = f32x4::from_array([19.0, 28.0, 37.0, 46.0]);
    assert_eq!(x.mul_sub(y, z), expected);
}

#[test]
fn test_mul_sub_f64() {
    // Same contract as the f32 case, on 64-bit lanes.
    let x = f64x4::from_array([2.0, 3.0, 4.0, 5.0]);
    let y = f64x4::splat(10.0);
    let z = f64x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let expected = f64x4::from_array([19.0, 28.0, 37.0, 46.0]);
    assert_eq!(x.mul_sub(y, z), expected);
}

#[test]
fn test_fma_accuracy_catastrophic_cancellation() {
    // (1+e)(1-e) - 1 = -e² exactly; the fused form must not be less
    // accurate than multiply-then-add near this cancellation point.
    let epsilon = 1e-4_f32;
    let a = f32x4::splat(1.0 + epsilon);
    let b = f32x4::splat(1.0 - epsilon);
    let c = f32x4::splat(-1.0);
    let exact = -epsilon * epsilon;

    let fused_err = (a.mul_add(b, c)[0] - exact).abs();
    let unfused_err = ((a * b + c)[0] - exact).abs();

    assert!(fused_err <= unfused_err);
}

#[test]
fn test_fma_accuracy_discriminant() {
    // Quadratic discriminant b² - 4ac with b = 1e8: fusing b*b with the
    // subtraction must be at least as close to the exact value, in relative
    // terms, as the unfused computation.
    let b = f64x2::splat(1e8);
    let four_ac = f64x2::splat(1.0);
    let exact = 1e16 - 1.0;

    let fused = b.mul_add(b, -four_ac)[0];
    let unfused = (b * b - four_ac)[0];

    let fused_rel = ((fused - exact) / exact).abs();
    let unfused_rel = ((unfused - exact) / exact).abs();
    assert!(fused_rel <= unfused_rel);
}

#[test]
fn test_fma_accuracy_polynomial() {
    // x² - 2x + 1 = (x - 1)² suffers cancellation near x = 1; Horner
    // evaluation with FMA should be no worse than the unfused form
    // (equality within 1e-15 is accepted).
    let x = f64x2::splat(1.00001);
    let (a, b, c) = (f64x2::splat(1.0), f64x2::splat(-2.0), f64x2::splat(1.0));
    let exact = (x[0] - 1.0) * (x[0] - 1.0);

    let fused_err = (a.mul_add(x, b).mul_add(x, c)[0] - exact).abs();
    let unfused = (a * x + b) * x + c;
    let unfused_err = (unfused[0] - exact).abs();

    assert!(fused_err < unfused_err || (fused_err - unfused_err).abs() < 1e-15);
}

#[test]
fn test_negative_values() {
    // Signs flow correctly through both the product and the addend/subtrahend.
    let neg = f32x4::from_array([-2.0, -3.0, -4.0, -5.0]);
    let scale = f32x4::splat(2.0);
    let offset = f32x4::splat(1.0);

    assert_eq!(
        neg.mul_add(scale, offset),
        f32x4::from_array([-3.0, -5.0, -7.0, -9.0])
    );
    assert_eq!(
        neg.mul_sub(scale, offset),
        f32x4::from_array([-5.0, -7.0, -9.0, -11.0])
    );
}

#[test]
fn test_infinity() {
    // An infinite input lane stays infinite through FMA; finite lanes are
    // computed normally.
    let vals = f32x4::from_array([f32::INFINITY, 1.0, 2.0, 3.0]);
    let out = vals.mul_add(f32x4::splat(2.0), f32x4::splat(1.0));
    assert_eq!(out[0], f32::INFINITY);
    assert_eq!(out[1], 3.0);
}

#[test]
fn test_nan_propagation() {
    // A NaN input lane produces a NaN output lane; other lanes are unaffected.
    let vals = f32x4::from_array([f32::NAN, 2.0, 3.0, 4.0]);
    let out = vals.mul_add(f32x4::splat(2.0), f32x4::splat(1.0));
    assert!(out[0].is_nan());
    assert_eq!(out[1], 5.0);
}

#[test]
fn test_different_sizes() {
    // mul_add is available at every lane count, not just 4.
    let two = f32x2::from_array([3.0, 4.0]);
    assert_eq!(
        two.mul_add(f32x2::splat(2.0), f32x2::splat(1.0)),
        f32x2::from_array([7.0, 9.0])
    );

    let eight = f32x8::splat(2.0);
    assert_eq!(
        eight.mul_add(f32x8::splat(3.0), f32x8::splat(4.0)),
        f32x8::splat(10.0)
    );
}

#[test]
fn test_polynomial_evaluation() {
    // Horner evaluation of 2x² + 3x + 5 via chained FMAs:
    // (2x + 3)x + 5, one fold step per remaining coefficient.
    let x = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let mut p = f32x4::splat(2.0);
    for coeff in [3.0, 5.0] {
        p = p.mul_add(x, f32x4::splat(coeff));
    }
    assert_eq!(p, f32x4::from_array([10.0, 19.0, 32.0, 49.0]));
}

#[test]
fn test_max_min_values() {
    // Multiplying by one and adding zero must leave extreme finite values intact.
    let extremes = f32x4::from_array([f32::MAX, f32::MIN, 1.0, -1.0]);
    let out = extremes.mul_add(f32x4::splat(1.0), f32x4::splat(0.0));
    assert_eq!(out[0], f32::MAX);
    assert_eq!(out[1], f32::MIN);
}

#[test]
fn test_subnormal_values() {
    // Half of MIN_POSITIVE is subnormal; doubling it should land back on a
    // finite value (MIN_POSITIVE) unless the hardware flushes subnormals.
    let tiny = f32::MIN_POSITIVE / 2.0;
    let result = f32x4::splat(tiny).mul_add(f32x4::splat(2.0), f32x4::splat(0.0));
    assert!(result[0].is_finite());

    // On platforms with flush-to-zero (FTZ) mode (e.g., ARM NEON), subnormal
    // values in SIMD operations may be flushed to zero for performance.
    // We accept either the mathematically correct result or zero.
    let expected = tiny * 2.0;
    assert!(
        result[0] == expected || result[0] == 0.0,
        "Expected {} (or 0.0 due to FTZ), got {}",
        expected,
        result[0]
    );
}