@@ -429,23 +429,24 @@ def int8_gemm_addmm_kernel(
429429 b_ptrs = b_ptr + offs_n [None , :] * K + offs_k [:, None ]
430430 a_s_ptrs = a_s_ptr + offs_m * k
431431
432- # Weight scale indexing: b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K)
433- # For this N tile (pid_n), we need scales[pid_n, :] across K iterations
434- b_s_base = b_s_ptr + pid_n * k
432+ # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks)
433+ # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major
434+ # For N tile pid_n, we need scales[pid_n, :] across K iterations
435+ # Address calculation: scale[pid_n, i] = base + pid_n * stride + i
436+ k_blocks = k # Number of K blocks for clarity
437+ b_s_base = b_s_ptr + pid_n * k_blocks
435438
436439 # Accumulate matmul result
437440 # Create accumulators outside the loop for better performance
438441 accumulator = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = tl .float32 )
439- #acc_int32 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
440- for i in range (k ):
442+ for i in range (k_blocks ):
441443 # Load int8 data - use other=0 (int) not 0.0 (float) to preserve int8 type
442444 a = tl .load (a_ptrs , mask = offs_k [None , :] < K - i * BLOCK_SIZE_K , other = 0 )
443445 b = tl .load (b_ptrs , mask = offs_k [:, None ] < K - i * BLOCK_SIZE_K , other = 0 )
444446 a_s = tl .load (a_s_ptrs )
445- # Load single scalar weight scale for this (N block, K block) pair
447+ # FIXED: Load single scalar weight scale for (pid_n, i) block pair
446448 b_s = tl .load (b_s_base + i )
447449 # INT8 matmul → INT32 acc, then cast to FP32 and apply per-block scaling
448- # Use explicit int32 accumulator to ensure int8 × int8 → int32 accumulation
449450 dot_prod = tl .dot (a , b , out_dtype = tl .int32 ) # int8 × int8 → int32
450451 accumulator += dot_prod .to (tl .float32 ) * a_s [:, None ] * b_s
451452 a_ptrs += BLOCK_SIZE_K
@@ -670,17 +671,19 @@ def int8_gemm_quant_kernel(
670671 b_ptrs = b_ptr + offs_n [None , :] * K + offs_k [:, None ]
671672 a_s_ptrs = a_s_ptr + offs_m * k
672673
673- # Weight scale indexing: b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K)
674- # For this N tile (pid_n), we need scales[pid_n, :] across K iterations
675- b_s_base = b_s_ptr + pid_n * k
674+ # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks)
675+ # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major
676+ # For N tile pid_n, we need scales[pid_n, :] across K iterations
677+ k_blocks = k # Number of K blocks for clarity
678+ b_s_base = b_s_ptr + pid_n * k_blocks
676679
677680 # Accumulate matmul result
678681 accumulator = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = tl .float32 )
679- for i in range (k ):
682+ for i in range (k_blocks ):
680683 a = tl .load (a_ptrs , mask = offs_k [None , :] < K - i * BLOCK_SIZE_K , other = 0 )
681684 b = tl .load (b_ptrs , mask = offs_k [:, None ] < K - i * BLOCK_SIZE_K , other = 0 )
682685 a_s = tl .load (a_s_ptrs )
683- # Load single scalar weight scale for this (N block, K block) pair
686+ # FIXED: Load single scalar weight scale for (pid_n, i) block pair
684687 b_s = tl .load (b_s_base + i )
685688 dot_prod = tl .dot (a , b , out_dtype = tl .int32 )
686689 accumulator += dot_prod .to (tl .float32 ) * a_s [:, None ] * b_s
@@ -783,17 +786,19 @@ def int8_gemm_addmm_quant_kernel(
783786 b_ptrs = b_ptr + offs_n [None , :] * K + offs_k [:, None ]
784787 a_s_ptrs = a_s_ptr + offs_m * k
785788
786- # Weight scale indexing: b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K)
787- # For this N tile (pid_n), we need scales[pid_n, :] across K iterations
788- b_s_base = b_s_ptr + pid_n * k
789+ # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks)
790+ # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major
791+ # For N tile pid_n, we need scales[pid_n, :] across K iterations
792+ k_blocks = k # Number of K blocks for clarity
793+ b_s_base = b_s_ptr + pid_n * k_blocks
789794
790795 # Accumulate matmul result
791796 accumulator = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = tl .float32 )
792- for i in range (k ):
797+ for i in range (k_blocks ):
793798 a = tl .load (a_ptrs , mask = offs_k [None , :] < K - i * BLOCK_SIZE_K , other = 0 )
794799 b = tl .load (b_ptrs , mask = offs_k [:, None ] < K - i * BLOCK_SIZE_K , other = 0 )
795800 a_s = tl .load (a_s_ptrs )
796- # Load single scalar weight scale for this (N block, K block) pair
801+ # FIXED: Load single scalar weight scale for (pid_n, i) block pair
797802 b_s = tl .load (b_s_base + i )
798803 dot_prod = tl .dot (a , b , out_dtype = tl .int32 )
799804 accumulator += dot_prod .to (tl .float32 ) * a_s [:, None ] * b_s
0 commit comments