Skip to content

Commit e5dad6c

Browse files
committed
Improve: Faster sparse dot product
1 parent d6e17b1 commit e5dad6c

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

include/simsimd/sparse.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,6 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_turin( //
552552

553553
// Broadcast index for last element (hoisted outside loop)
554554
__m256i const last_idx = _mm256_set1_epi16(15);
555-
556555
while (a + 16 <= a_end && b + 16 <= b_end) {
557556
a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a);
558557
b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b);
@@ -648,6 +647,8 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( //
648647
} a_vec, b_vec, product_vec;
649648
product_vec.ymmps = _mm256_setzero_ps();
650649

650+
// Broadcast index for last element (hoisted outside loop)
651+
__m256i const last_idx = _mm256_set1_epi16(15);
651652
while (a + 16 <= a_end && b + 16 <= b_end) {
652653
a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a);
653654
b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b);
@@ -694,12 +695,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( //
694695
product_vec.ymmps = _mm256_dpbf16_ps(product_vec.ymmps, (__m256bh)a_weights_vec, (__m256bh)b_weights_vec);
695696
}
696697

697-
__m256i a_last_broadcasted = _mm256_set1_epi16(*(short const *)&a_max);
698-
__m256i b_last_broadcasted = _mm256_set1_epi16(*(short const *)&b_max);
698+
__m256i a_last_broadcasted = _mm256_permutexvar_epi16(last_idx, a_vec.ymm);
699+
__m256i b_last_broadcasted = _mm256_permutexvar_epi16(last_idx, b_vec.ymm);
699700
__mmask16 a_step_mask = _mm256_cmple_epu16_mask(a_vec.ymm, b_last_broadcasted);
700701
__mmask16 b_step_mask = _mm256_cmple_epu16_mask(b_vec.ymm, a_last_broadcasted);
701-
int a_step = 32 - _lzcnt_u32((simsimd_u32_t)a_step_mask); //? Is this correct? Needs testing!
702-
int b_step = 32 - _lzcnt_u32((simsimd_u32_t)b_step_mask);
702+
simsimd_size_t a_step = _tzcnt_u32(~(simsimd_u32_t)a_step_mask | 0x10000);
703+
simsimd_size_t b_step = _tzcnt_u32(~(simsimd_u32_t)b_step_mask | 0x10000);
703704
a += a_step, a_weights += a_step;
704705
b += b_step, b_weights += b_step;
705706
}
@@ -733,6 +734,8 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //
733734
} a_vec, b_vec, product_vec;
734735
product_vec.ymm = _mm256_setzero_si256();
735736

737+
// Broadcast index for last element (hoisted outside loop)
738+
__m256i const last_idx = _mm256_set1_epi16(15);
736739
while (a + 16 <= a_end && b + 16 <= b_end) {
737740
a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a);
738741
b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b);
@@ -779,12 +782,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //
779782
product_vec.ymm = _mm256_dpwssds_epi32(product_vec.ymm, a_weights_vec, b_weights_vec);
780783
}
781784

782-
__m256i a_last_broadcasted = _mm256_set1_epi16(*(short const *)&a_max);
783-
__m256i b_last_broadcasted = _mm256_set1_epi16(*(short const *)&b_max);
785+
__m256i a_last_broadcasted = _mm256_permutexvar_epi16(last_idx, a_vec.ymm);
786+
__m256i b_last_broadcasted = _mm256_permutexvar_epi16(last_idx, b_vec.ymm);
784787
__mmask16 a_step_mask = _mm256_cmple_epu16_mask(a_vec.ymm, b_last_broadcasted);
785788
__mmask16 b_step_mask = _mm256_cmple_epu16_mask(b_vec.ymm, a_last_broadcasted);
786-
int a_step = 32 - _lzcnt_u32((simsimd_u32_t)a_step_mask); //? Is this correct? Needs testing!
787-
int b_step = 32 - _lzcnt_u32((simsimd_u32_t)b_step_mask);
789+
simsimd_size_t a_step = _tzcnt_u32(~(simsimd_u32_t)a_step_mask | 0x10000);
790+
simsimd_size_t b_step = _tzcnt_u32(~(simsimd_u32_t)b_step_mask | 0x10000);
788791
a += a_step, a_weights += a_step;
789792
b += b_step, b_weights += b_step;
790793
}

0 commit comments

Comments
 (0)