Skip to content

Commit 48bd2f9

Browse files
committed
fix
1 parent 1083c1a commit 48bd2f9

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

debug_int8_gemm.py

Lines changed: 96 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -148,12 +148,108 @@ def test_scale_loading():
148148
print("=" * 80)
149149

150150

151+
def test_exact_failing_case():
    """Reproduce the EXACT test case that's failing"""
    # Fixed configuration mirroring test_triton_linear_from_raw_int8_and_scales.
    dev = torch.device('cuda')
    block_size = 128
    torch.manual_seed(123)

    batch_size, seq_len = 2, 8
    in_features, out_features = 256, 512

    banner = "=" * 80
    print(banner)
    print("EXACT FAILING TEST CASE: Reproduce test_triton_linear_from_raw_int8_and_scales")
    print(banner)

    # Manually create int8 data and per-block scales for the input (activation).
    # Activation layout: (batch_size, seq_len, in_features), one scale per
    # block_size-wide slice of the feature dimension.
    input_int8 = torch.randint(
        -127, 127, (batch_size, seq_len, in_features), dtype=torch.int8, device=dev
    )
    input_scale = 0.1 * torch.rand(
        batch_size, seq_len, in_features // block_size, dtype=torch.float32, device=dev
    )

    # Manually create int8 data and per-(block, block) scales for the weight.
    # Weight layout: (out_features, in_features).
    weight_int8 = torch.randint(
        -127, 127, (out_features, in_features), dtype=torch.int8, device=dev
    )
    weight_scale = 0.1 * torch.rand(
        out_features // block_size, in_features // block_size,
        dtype=torch.float32, device=dev
    )

    # Bias
    bias = torch.randn(out_features, dtype=torch.float32, device=dev)

    print(f"Input shape: {input_int8.shape}")
    print(f"Input scale shape: {input_scale.shape}")
    print(f"Weight shape: {weight_int8.shape}")
    print(f"Weight scale shape: {weight_scale.shape}")
    print(f"Bias shape: {bias.shape}")
    print()

    # Method 1: Call INT8 GEMM via Triton/fallback
    print("Calling Triton/fallback...")
    out_triton = _int8_gemm_triton_or_fallback(
        input_int8, input_scale, weight_int8, weight_scale, block_size, bias=bias, out_quant=False
    )

    # Method 2: Call PyTorch INT8 GEMM fallback directly
    print("Calling PyTorch fallback...")
    out_pytorch = _int8_gemm_pytorch_fallback(
        input_int8, input_scale, weight_int8, weight_scale, block_size, bias=bias
    )

    # Promote both outputs to float32 so the comparison is apples-to-apples.
    out_triton_fp32 = out_triton.to(torch.float32)
    out_pytorch_fp32 = out_pytorch.to(torch.float32)

    # Compare Method 1 vs Method 2: Triton vs PyTorch INT8 GEMM.
    abs_diff = (out_triton_fp32 - out_pytorch_fp32).abs()
    mean_abs_diff = abs_diff.mean().item()
    max_abs_diff = abs_diff.max().item()

    print(f"\nComparison:")
    print(f"  Output shape: {out_triton.shape}")
    print(f"  Triton sample [0,0,:5]:  {out_triton[0,0,:5].cpu()}")
    print(f"  PyTorch sample [0,0,:5]: {out_pytorch[0,0,:5].cpu()}")
    print(f"  Difference [0,0,:5]:     {abs_diff[0,0,:5].cpu()}")
    print(f"\n  Mean absolute difference: {mean_abs_diff:.6f}")
    print(f"  Max absolute difference:  {max_abs_diff:.6f}")
    print(f"  Test threshold: 0.001000")
    print(f"  PASS: {mean_abs_diff < 1e-3}")

    # Locate the single worst element: turn argmax's flat index back into a
    # (batch, seq, feature) triple via successive divmod.
    flat = abs_diff.argmax().item()
    shape = abs_diff.shape
    idx_0, rem = divmod(flat, shape[1] * shape[2])
    idx_1, idx_2 = divmod(rem, shape[2])

    print(f"\n  Max difference location: [{idx_0}, {idx_1}, {idx_2}]")
    print(f"  Triton value:  {out_triton_fp32[idx_0, idx_1, idx_2].item():.6f}")
    print(f"  PyTorch value: {out_pytorch_fp32[idx_0, idx_1, idx_2].item():.6f}")

    # Check whether the error concentrates in particular output (N) blocks,
    # which would point at a per-block scale-indexing bug.
    print(f"\n  Difference by N-block:")
    for n_block in range(out_features // block_size):
        start = n_block * block_size
        end = start + block_size
        block_diff = abs_diff[:, :, start:end].mean().item()
        print(f"    N-block {n_block} (outputs {start}:{end}): mean diff = {block_diff:.6f}")

    print(banner)
243+
244+
151245
if __name__ == "__main__":
    # Debug suite only makes sense on a CUDA machine (kernels under test are GPU-only).
    if not torch.cuda.is_available():
        print("CUDA not available, skipping tests")
    else:
        print("CUDA available, running tests...\n")
        for i, case in enumerate((test_simple_case, test_scale_loading, test_exact_failing_case)):
            if i:
                print("\n")
            case()
159255

0 commit comments

Comments (0)