Skip to content

Commit 108a8b5

Browse files
committed
Fix: Avoid sqrt(0) in probability.h
1 parent 31195e9 commit 108a8b5

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

include/simsimd/probability.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_
108108
d += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon)); \
109109
d += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon)); \
110110
} \
111-
*result = SIMSIMD_SQRT(((simsimd_distance_t)d / 2)); \
111+
simsimd_distance_t d_half = ((simsimd_distance_t)d / 2); \
112+
*result = d_half > 0 ? SIMSIMD_SQRT(d_half) : 0; \
112113
}
113114

114115
SIMSIMD_MAKE_KL(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f64_serial
@@ -225,7 +226,7 @@ SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const *a, simsimd_f32_t co
225226

226227
simsimd_f32_t log2_normalizer = 0.693147181f;
227228
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer / 2;
228-
*result = _simsimd_sqrt_f32_neon(sum);
229+
*result = sum > 0 ? _simsimd_sqrt_f32_neon(sum) : 0;
229230
}
230231

231232
#pragma clang attribute pop
@@ -298,7 +299,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const *a, simsimd_f16_t co
298299

299300
simsimd_f32_t log2_normalizer = 0.693147181f;
300301
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer / 2;
301-
*result = _simsimd_sqrt_f32_neon(sum);
302+
*result = sum > 0 ? _simsimd_sqrt_f32_neon(sum) : 0;
302303
}
303304

304305
#pragma clang attribute pop
@@ -403,7 +404,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t
403404
simsimd_f32_t log2_normalizer = 0.693147181f;
404405
simsimd_f32_t sum = _simsimd_reduce_f32x8_haswell(sum_vec);
405406
sum *= log2_normalizer / 2;
406-
*result = _simsimd_sqrt_f32_haswell(sum);
407+
*result = sum > 0 ? _simsimd_sqrt_f32_haswell(sum) : 0;
407408
}
408409

409410
#pragma clang attribute pop
@@ -498,7 +499,7 @@ SIMSIMD_PUBLIC void simsimd_js_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t
498499
simsimd_f32_t log2_normalizer = 0.693147181f;
499500
simsimd_f32_t sum = _mm512_reduce_add_ps(_mm512_add_ps(sum_a_vec, sum_b_vec));
500501
sum *= log2_normalizer / 2;
501-
*result = _simsimd_sqrt_f32_haswell(sum);
502+
*result = sum > 0 ? _simsimd_sqrt_f32_haswell(sum) : 0;
502503
}
503504

504505
#pragma clang attribute pop
@@ -591,7 +592,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_
591592
simsimd_f32_t log2_normalizer = 0.693147181f;
592593
simsimd_f32_t sum = _mm512_reduce_add_ph(_mm512_add_ph(sum_a_vec, sum_b_vec));
593594
sum *= log2_normalizer / 2;
594-
*result = _simsimd_sqrt_f32_haswell(sum);
595+
*result = sum > 0 ? _simsimd_sqrt_f32_haswell(sum) : 0;
595596
}
596597

597598
#pragma clang attribute pop

0 commit comments

Comments
 (0)