@@ -552,7 +552,6 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_turin( //
552552
553553 // Broadcast index for last element (hoisted outside loop)
554554 __m256i const last_idx = _mm256_set1_epi16 (15 );
555-
556555 while (a + 16 <= a_end && b + 16 <= b_end) {
557556 a_vec.ymm = _mm256_lddqu_si256 ((__m256i const *)a);
558557 b_vec.ymm = _mm256_lddqu_si256 ((__m256i const *)b);
@@ -648,6 +647,8 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( //
648647 } a_vec, b_vec, product_vec;
649648 product_vec.ymmps = _mm256_setzero_ps ();
650649
650+ // Broadcast index for last element (hoisted outside loop)
651+ __m256i const last_idx = _mm256_set1_epi16 (15 );
651652 while (a + 16 <= a_end && b + 16 <= b_end) {
652653 a_vec.ymm = _mm256_lddqu_si256 ((__m256i const *)a);
653654 b_vec.ymm = _mm256_lddqu_si256 ((__m256i const *)b);
@@ -694,12 +695,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( //
694695 product_vec.ymmps = _mm256_dpbf16_ps (product_vec.ymmps , (__m256bh)a_weights_vec, (__m256bh)b_weights_vec);
695696 }
696697
697- __m256i a_last_broadcasted = _mm256_set1_epi16 (*( short const *)&a_max );
698- __m256i b_last_broadcasted = _mm256_set1_epi16 (*( short const *)&b_max );
698+ __m256i a_last_broadcasted = _mm256_permutexvar_epi16 (last_idx, a_vec. ymm );
699+ __m256i b_last_broadcasted = _mm256_permutexvar_epi16 (last_idx, b_vec. ymm );
699700 __mmask16 a_step_mask = _mm256_cmple_epu16_mask (a_vec.ymm , b_last_broadcasted);
700701 __mmask16 b_step_mask = _mm256_cmple_epu16_mask (b_vec.ymm , a_last_broadcasted);
701- int a_step = 32 - _lzcnt_u32 ( (simsimd_u32_t )a_step_mask); // ? Is this correct? Needs testing!
702- int b_step = 32 - _lzcnt_u32 ( (simsimd_u32_t )b_step_mask);
702+ simsimd_size_t a_step = _tzcnt_u32 (~ (simsimd_u32_t )a_step_mask | 0x10000 );
703+ simsimd_size_t b_step = _tzcnt_u32 (~ (simsimd_u32_t )b_step_mask | 0x10000 );
703704 a += a_step, a_weights += a_step;
704705 b += b_step, b_weights += b_step;
705706 }
@@ -733,6 +734,8 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //
733734 } a_vec, b_vec, product_vec;
734735 product_vec.ymm = _mm256_setzero_si256 ();
735736
737+ // Broadcast index for last element (hoisted outside loop)
738+ __m256i const last_idx = _mm256_set1_epi16 (15 );
736739 while (a + 16 <= a_end && b + 16 <= b_end) {
737740 a_vec.ymm = _mm256_lddqu_si256 ((__m256i const *)a);
738741 b_vec.ymm = _mm256_lddqu_si256 ((__m256i const *)b);
@@ -779,12 +782,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //
779782 product_vec.ymm = _mm256_dpwssds_epi32 (product_vec.ymm , a_weights_vec, b_weights_vec);
780783 }
781784
782- __m256i a_last_broadcasted = _mm256_set1_epi16 (*( short const *)&a_max );
783- __m256i b_last_broadcasted = _mm256_set1_epi16 (*( short const *)&b_max );
785+ __m256i a_last_broadcasted = _mm256_permutexvar_epi16 (last_idx, a_vec. ymm );
786+ __m256i b_last_broadcasted = _mm256_permutexvar_epi16 (last_idx, b_vec. ymm );
784787 __mmask16 a_step_mask = _mm256_cmple_epu16_mask (a_vec.ymm , b_last_broadcasted);
785788 __mmask16 b_step_mask = _mm256_cmple_epu16_mask (b_vec.ymm , a_last_broadcasted);
786- int a_step = 32 - _lzcnt_u32 ( (simsimd_u32_t )a_step_mask); // ? Is this correct? Needs testing!
787- int b_step = 32 - _lzcnt_u32 ( (simsimd_u32_t )b_step_mask);
789+ simsimd_size_t a_step = _tzcnt_u32 (~ (simsimd_u32_t )a_step_mask | 0x10000 );
790+ simsimd_size_t b_step = _tzcnt_u32 (~ (simsimd_u32_t )b_step_mask | 0x10000 );
788791 a += a_step, a_weights += a_step;
789792 b += b_step, b_weights += b_step;
790793 }
0 commit comments