
Commit d24f73d (1 parent: 1cc9ae2)

Commit message: debug

2 files changed: +163, -7 lines

comfy/int8_kernels.py

Lines changed: 8 additions & 7 deletions
@@ -348,22 +348,23 @@ def int8_gemm_kernel(
     b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
     a_s_ptrs = a_s_ptr + offs_m * k

-    # Weight scale indexing: b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K)
-    # For this N tile (pid_n), we need scales[pid_n, :] across K iterations
-    b_s_base = b_s_ptr + pid_n * k
+    # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks)
+    # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major
+    # For N tile pid_n, we need scales[pid_n, :] across K iterations
+    # Address calculation: scale[pid_n, i] = base + pid_n * stride + i
+    k_blocks = k  # Number of K blocks for clarity
+    b_s_base = b_s_ptr + pid_n * k_blocks

     # Create accumulators outside the loop for better performance
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    #acc_int32 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
-    for i in range(k):
+    for i in range(k_blocks):
         # Load int8 data - use other=0 (int) not 0.0 (float) to preserve int8 type
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0)
         a_s = tl.load(a_s_ptrs)
-        # Load single scalar weight scale for this (N block, K block) pair
+        # FIXED: Load single scalar weight scale for (pid_n, i) block pair
         b_s = tl.load(b_s_base + i)
         # INT8 matmul → INT32 acc, then cast to FP32 and apply per-block scaling
-        # Use explicit int32 accumulator to ensure int8 × int8 → int32 accumulation
         dot_prod = tl.dot(a, b, out_dtype=tl.int32)  # int8 × int8 → int32
         accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s
         a_ptrs += BLOCK_SIZE_K
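
For intuition, the address calculation spelled out in the new comments is ordinary row-major indexing into the flat weight-scale buffer. A minimal sketch (plain PyTorch on CPU, with hypothetical sizes; not part of the commit) that exercises the same arithmetic:

import torch

# Hypothetical block layout mirroring the kernel: b_s has shape (N_blocks, K_blocks)
N, K, BLOCK_SIZE_K = 256, 256, 128
n_blocks, k_blocks = N // BLOCK_SIZE_K, K // BLOCK_SIZE_K

b_s = torch.arange(n_blocks * k_blocks, dtype=torch.float32).reshape(n_blocks, k_blocks)
b_s_flat = b_s.reshape(-1)  # how the kernel sees the buffer: flat, row-major

for pid_n in range(n_blocks):
    for i in range(k_blocks):
        # Matches the kernel: b_s_base = b_s_ptr + pid_n * k_blocks, then load at b_s_base + i
        assert b_s_flat[pid_n * k_blocks + i] == b_s[pid_n, i]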

debug_int8_gemm.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
#!/usr/bin/env python3
"""
Debug script to test INT8 GEMM with simple known values.
This will help us understand what's going wrong.
"""
import torch
import sys

# Add comfy to path
sys.path.insert(0, '/Users/l_y_o/Work/ComfyUI')

from comfy.quant_ops import _int8_gemm_pytorch_fallback, _int8_gemm_triton_or_fallback

def test_simple_case():
    """Test with very simple values to see the difference"""
    device = torch.device('cuda')
    block_size = 128

    # Very simple case: 1 batch, small dimensions
    M, K, N = 128, 256, 256

    # Create simple int8 data: all ones
    input_int8 = torch.ones((M, K), dtype=torch.int8, device=device)
    weight_int8 = torch.ones((N, K), dtype=torch.int8, device=device)

    # Create simple scales: all 0.01
    input_scale = torch.full((M, K // block_size), 0.01, dtype=torch.float32, device=device)
    weight_scale = torch.full((N // block_size, K // block_size), 0.01, dtype=torch.float32, device=device)

    # No bias for simplicity
    bias = None

    print("=" * 80)
    print("SIMPLE TEST CASE: all ones, scales=0.01")
    print("=" * 80)
    print(f"Input shape: {input_int8.shape}, scales: {input_scale.shape}")
    print(f"Weight shape: {weight_int8.shape}, scales: {weight_scale.shape}")
    print(f"Expected: Each output element = sum(1*0.01 * 1*0.01 for k in range(K))")
    print(f"          = K * (0.01 * 0.01) = {K} * 0.0001 = {K * 0.0001}")
    print()

    # Method 1: Triton
    try:
        output_triton = _int8_gemm_triton_or_fallback(
            input_int8, input_scale, weight_int8, weight_scale, block_size, bias=bias, out_quant=False
        )
        print(f"Triton output sample (first 5): {output_triton[0, :5].cpu()}")
        print(f"Triton output mean: {output_triton.mean().item():.6f}")
        print(f"Triton output [0,0]: {output_triton[0, 0].item():.6f}")
    except Exception as e:
        print(f"Triton failed: {e}")
        output_triton = None

    # Method 2: PyTorch
    output_pytorch = _int8_gemm_pytorch_fallback(
        input_int8, input_scale, weight_int8, weight_scale, block_size, bias=bias
    )
    print(f"\nPyTorch output sample (first 5): {output_pytorch[0, :5].cpu()}")
    print(f"PyTorch output mean: {output_pytorch.mean().item():.6f}")
    print(f"PyTorch output [0,0]: {output_pytorch[0, 0].item():.6f}")

    if output_triton is not None:
        diff = (output_triton.float() - output_pytorch.float()).abs()
        print(f"\nDifference mean: {diff.mean().item():.6f}")
        print(f"Difference max: {diff.max().item():.6f}")
        print(f"Difference [0,0]: {diff[0, 0].item():.6f}")

    print("\n" + "=" * 80)


def test_scale_loading():
    """Test to see which scales are being used"""
    device = torch.device('cuda')
    block_size = 128

    M, K, N = 128, 256, 256

    # Create int8 data: all ones
    input_int8 = torch.ones((M, K), dtype=torch.int8, device=device)
    weight_int8 = torch.ones((N, K), dtype=torch.int8, device=device)

    # Create UNIQUE scales to trace which ones are being used
    # Input scales: [0.01, 0.02] for the two K blocks
    input_scale = torch.tensor([[0.01, 0.02]] * M, dtype=torch.float32, device=device)

    # Weight scales: unique value for each position
    # Shape: (N//block_size, K//block_size) = (2, 2)
    weight_scale = torch.tensor([
        [0.10, 0.20],  # N-block 0: K-block 0=0.10, K-block 1=0.20
        [0.30, 0.40],  # N-block 1: K-block 0=0.30, K-block 1=0.40
    ], dtype=torch.float32, device=device)

    print("=" * 80)
    print("SCALE LOADING TEST: unique scales to trace usage")
    print("=" * 80)
    print(f"Input scales shape: {input_scale.shape}")
    print(f"  Values: [0.01, 0.02] for K-blocks [0, 1]")
    print(f"\nWeight scales shape: {weight_scale.shape}")
    print(f"  N-block 0: K-blocks [0.10, 0.20]")
    print(f"  N-block 1: K-blocks [0.30, 0.40]")
    print()
    print("For output[i, j], we should get:")
    print("  j in [0:128]   (N-block 0): sum of [block0: 128*1*0.01*1*0.10, block1: 128*1*0.02*1*0.20]")
    print("                              = 128*0.001 + 128*0.004 = 0.128 + 0.512 = 0.640")
    print("  j in [128:256] (N-block 1): sum of [block0: 128*1*0.01*1*0.30, block1: 128*1*0.02*1*0.40]")
    print("                              = 128*0.003 + 128*0.008 = 0.384 + 1.024 = 1.408")
    print()

    # PyTorch reference
    output_pytorch = _int8_gemm_pytorch_fallback(
        input_int8, input_scale, weight_int8, weight_scale, block_size, bias=None
    )

    print("PyTorch output:")
    print(f"  output[0, 0]   (N-block 0): {output_pytorch[0, 0].item():.6f} (expected: 0.640)")
    print(f"  output[0, 128] (N-block 1): {output_pytorch[0, 128].item():.6f} (expected: 1.408)")
    print(f"  Mean of N-block 0: {output_pytorch[0, :128].mean().item():.6f}")
    print(f"  Mean of N-block 1: {output_pytorch[0, 128:].mean().item():.6f}")

    # Triton
    try:
        output_triton = _int8_gemm_triton_or_fallback(
            input_int8, input_scale, weight_int8, weight_scale, block_size, bias=None, out_quant=False
        )

        print("\nTriton output:")
        print(f"  output[0, 0]   (N-block 0): {output_triton[0, 0].item():.6f} (expected: 0.640)")
        print(f"  output[0, 128] (N-block 1): {output_triton[0, 128].item():.6f} (expected: 1.408)")
        print(f"  Mean of N-block 0: {output_triton[0, :128].mean().item():.6f}")
        print(f"  Mean of N-block 1: {output_triton[0, 128:].mean().item():.6f}")

        # Compare
        diff = (output_triton.float() - output_pytorch.float()).abs()
        print(f"\nDifference:")
        print(f"  [0, 0]: {diff[0, 0].item():.6f}")
        print(f"  [0, 128]: {diff[0, 128].item():.6f}")
        print(f"  Mean: {diff.mean().item():.6f}, Max: {diff.max().item():.6f}")

    except Exception as e:
        print(f"\nTriton failed: {e}")
        import traceback
        traceback.print_exc()

    print("=" * 80)


if __name__ == "__main__":
    if torch.cuda.is_available():
        print("CUDA available, running tests...\n")
        test_simple_case()
        print("\n")
        test_scale_loading()
    else:
        print("CUDA not available, skipping tests")
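
The expected values printed by both tests follow from block-wise dequantization followed by an ordinary matmul. A minimal reference sketch of that computation (plain PyTorch; this is an assumption about what _int8_gemm_pytorch_fallback computes, since its implementation is not part of this commit):

import torch

def blockwise_int8_gemm_reference(input_int8, input_scale, weight_int8, weight_scale, block_size):
    """Hypothetical reference: expand per-block scales to per-element scales,
    dequantize to float32, then compute input @ weight.T."""
    a_scale = input_scale.repeat_interleave(block_size, dim=1)     # (M, K)
    w_scale = weight_scale.repeat_interleave(block_size, dim=0)    # (N, K // block_size)
    w_scale = w_scale.repeat_interleave(block_size, dim=1)         # (N, K)
    a_deq = input_int8.float() * a_scale
    w_deq = weight_int8.float() * w_scale
    return a_deq @ w_deq.T                                         # (M, N)

With the unique scales from test_scale_loading(), this reference reproduces the hand-computed expectations: output[0, 0] = 0.640 and output[0, 128] = 1.408.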
