
Commit 001e0e0

fix
1 parent d24f73d


comfy/int8_kernels.py

Lines changed: 22 additions & 17 deletions
@@ -429,23 +429,24 @@ def int8_gemm_addmm_kernel(
     b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
     a_s_ptrs = a_s_ptr + offs_m * k
 
-    # Weight scale indexing: b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K)
-    # For this N tile (pid_n), we need scales[pid_n, :] across K iterations
-    b_s_base = b_s_ptr + pid_n * k
+    # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks)
+    # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major
+    # For N tile pid_n, we need scales[pid_n, :] across K iterations
+    # Address calculation: scale[pid_n, i] = base + pid_n * stride + i
+    k_blocks = k  # Number of K blocks for clarity
+    b_s_base = b_s_ptr + pid_n * k_blocks
 
     # Accumulate matmul result
     # Create accumulators outside the loop for better performance
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    #acc_int32 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
-    for i in range(k):
+    for i in range(k_blocks):
         # Load int8 data - use other=0 (int) not 0.0 (float) to preserve int8 type
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0)
         a_s = tl.load(a_s_ptrs)
-        # Load single scalar weight scale for this (N block, K block) pair
+        # FIXED: Load single scalar weight scale for (pid_n, i) block pair
         b_s = tl.load(b_s_base + i)
         # INT8 matmul → INT32 acc, then cast to FP32 and apply per-block scaling
-        # Use explicit int32 accumulator to ensure int8 × int8 → int32 accumulation
         dot_prod = tl.dot(a, b, out_dtype=tl.int32)  # int8 × int8 → int32
         accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s
         a_ptrs += BLOCK_SIZE_K
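
As a reference for what this hunk computes, the blockwise scaling can be written out in plain NumPy. This is an illustrative sketch, not code from the commit: it assumes activation scales a_s of shape (M, K // BLOCK_SIZE_K) with one scale per (row, K-block) pair, weight scales b_s of shape (N // BLOCK_SIZE_K, K // BLOCK_SIZE_K) in row-major order as the comments above describe, and that the kernel advances its activation-scale pointer once per K iteration (not shown in the hunk).

    # Reference model (NumPy) of the blockwise int8 GEMM above --
    # a sketch under the assumptions stated in the lead-in.
    import numpy as np

    def int8_gemm_blockwise_ref(a, a_s, b, b_s, block_k):
        # a: (M, K) int8; a_s: (M, K//block_k) float32 activation scales
        # b: (N, K) int8; b_s: (N//block_k, K//block_k) float32 weight scales
        M, K = a.shape
        N = b.shape[0]
        k_blocks = K // block_k
        b_s_flat = b_s.reshape(-1)  # row-major buffer, as the kernel sees it
        acc = np.zeros((M, N), dtype=np.float32)
        for i in range(k_blocks):
            ks = slice(i * block_k, (i + 1) * block_k)
            # int8 × int8 → int32 partial product for K block i
            dot = a[:, ks].astype(np.int32) @ b[:, ks].astype(np.int32).T
            for n_blk in range(N // block_k):
                ns = slice(n_blk * block_k, (n_blk + 1) * block_k)
                # same address arithmetic as the kernel:
                # b_s_base = b_s_ptr + n_blk * k_blocks, then load b_s_base + i
                w_scale = b_s_flat[n_blk * k_blocks + i]
                acc[:, ns] += dot[:, ns].astype(np.float32) * a_s[:, i:i + 1] * w_scale
        return acc

Accumulating each int8 × int8 block product in int32 and only then scaling to float32 mirrors the kernel's tl.dot(..., out_dtype=tl.int32) followed by the per-block multiply.
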
@@ -670,17 +671,19 @@ def int8_gemm_quant_kernel(
     b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
     a_s_ptrs = a_s_ptr + offs_m * k
 
-    # Weight scale indexing: b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K)
-    # For this N tile (pid_n), we need scales[pid_n, :] across K iterations
-    b_s_base = b_s_ptr + pid_n * k
+    # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks)
+    # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major
+    # For N tile pid_n, we need scales[pid_n, :] across K iterations
+    k_blocks = k  # Number of K blocks for clarity
+    b_s_base = b_s_ptr + pid_n * k_blocks
 
     # Accumulate matmul result
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for i in range(k):
+    for i in range(k_blocks):
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0)
         a_s = tl.load(a_s_ptrs)
-        # Load single scalar weight scale for this (N block, K block) pair
+        # FIXED: Load single scalar weight scale for (pid_n, i) block pair
         b_s = tl.load(b_s_base + i)
         dot_prod = tl.dot(a, b, out_dtype=tl.int32)
         accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s
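
The same scale addressing recurs here. As a quick sanity check on the comment's claim, the flat offset pid_n * k_blocks + i is exactly NumPy's C-order (row-major) index into a (N_blocks, K_blocks) array; the toy sizes below are arbitrary, chosen only for illustration:

    import numpy as np

    n_blocks, k_blocks = 4, 3  # arbitrary toy sizes
    b_s = np.arange(n_blocks * k_blocks, dtype=np.float32).reshape(n_blocks, k_blocks)
    flat = b_s.reshape(-1)  # the flat buffer that b_s_ptr would point at
    for pid_n in range(n_blocks):
        for i in range(k_blocks):
            # kernel: b_s_base = b_s_ptr + pid_n * k_blocks; load b_s_base + i
            assert flat[pid_n * k_blocks + i] == b_s[pid_n, i]
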
@@ -783,17 +786,19 @@ def int8_gemm_addmm_quant_kernel(
     b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
     a_s_ptrs = a_s_ptr + offs_m * k
 
-    # Weight scale indexing: b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K)
-    # For this N tile (pid_n), we need scales[pid_n, :] across K iterations
-    b_s_base = b_s_ptr + pid_n * k
+    # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks)
+    # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major
+    # For N tile pid_n, we need scales[pid_n, :] across K iterations
+    k_blocks = k  # Number of K blocks for clarity
+    b_s_base = b_s_ptr + pid_n * k_blocks
 
     # Accumulate matmul result
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for i in range(k):
+    for i in range(k_blocks):
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0)
         a_s = tl.load(a_s_ptrs)
-        # Load single scalar weight scale for this (N block, K block) pair
+        # FIXED: Load single scalar weight scale for (pid_n, i) block pair
         b_s = tl.load(b_s_base + i)
         dot_prod = tl.dot(a, b, out_dtype=tl.int32)
         accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s
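
For context on where these scale tensors come from, a minimal host-side blockwise quantizer that produces the (N // BLOCK_SIZE_K, K // BLOCK_SIZE_K) row-major weight-scale layout the kernels index is sketched below. The function name, the block size default, and the max-abs scheme are assumptions for illustration; the repository's actual quantization code may differ.

    # Hypothetical host-side blockwise weight quantization (PyTorch).
    # Only the scale layout is taken from the kernel comments above.
    import torch

    def quantize_weight_blockwise(w: torch.Tensor, block: int = 128):
        """w: (N, K) float weight -> (int8 weight, (N//block, K//block) scales)."""
        N, K = w.shape
        tiles = w.reshape(N // block, block, K // block, block)
        # max-abs per (N-block, K-block) tile, mapped onto the int8 range
        amax = tiles.abs().amax(dim=(1, 3)).clamp(min=1e-8)
        scales = (amax / 127.0).to(torch.float32)   # (N//block, K//block)
        q = tiles / scales[:, None, :, None]
        w_q = q.round().clamp(-127, 127).to(torch.int8).reshape(N, K)
        return w_q, scales.contiguous()             # row-major, as indexed

Dividing each tile by its max-abs/127 scale maps it onto the int8 range; the kernels then multiply the int32 partial products by these same scales to recover float32 results.
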
