@@ -148,12 +148,108 @@ def test_scale_loading():
148148 print ("=" * 80 )
149149
150150
def test_exact_failing_case(block_size=128, batch_size=2, seq_len=8,
                            in_features=256, out_features=512, threshold=1e-3):
    """Reproduce the EXACT test case that's failing.

    Builds identical random int8 activations/weights plus per-block float
    scales, runs both the Triton-or-fallback INT8 GEMM and the pure-PyTorch
    fallback, and reports how far apart their outputs are (overall, at the
    worst element, and per N-block of the output).

    Args:
        block_size: quantization block size along the feature dimensions.
        batch_size, seq_len, in_features, out_features: problem shape;
            the defaults match the originally failing test
            (test_triton_linear_from_raw_int8_and_scales).
        threshold: mean-absolute-difference threshold used for the PASS line.

    Returns:
        True if the mean absolute difference is below ``threshold``.
    """
    device = torch.device('cuda')
    torch.manual_seed(123)

    print("=" * 80)
    print("EXACT FAILING TEST CASE: Reproduce test_triton_linear_from_raw_int8_and_scales")
    print("=" * 80)

    # Manually create int8 data and scales for input (activation).
    # Input shape: (batch_size, seq_len, in_features); one scale per
    # block_size-wide slice of the feature dimension.
    input_int8 = torch.randint(-127, 127, (batch_size, seq_len, in_features),
                               dtype=torch.int8, device=device)
    input_scale = torch.rand(batch_size, seq_len, in_features // block_size,
                             dtype=torch.float32, device=device) * 0.1

    # Manually create int8 data and scales for weight.
    # Weight shape: (out_features, in_features); one scale per
    # (N-block, K-block) tile.
    weight_int8 = torch.randint(-127, 127, (out_features, in_features),
                                dtype=torch.int8, device=device)
    weight_scale = torch.rand(out_features // block_size, in_features // block_size,
                              dtype=torch.float32, device=device) * 0.1

    # Bias
    bias = torch.randn(out_features, dtype=torch.float32, device=device)

    print(f"Input shape: {input_int8.shape}")
    print(f"Input scale shape: {input_scale.shape}")
    print(f"Weight shape: {weight_int8.shape}")
    print(f"Weight scale shape: {weight_scale.shape}")
    print(f"Bias shape: {bias.shape}")
    print()

    # Method 1: Call INT8 GEMM via Triton (or its automatic fallback).
    print("Calling Triton/fallback...")
    output_triton = _int8_gemm_triton_or_fallback(
        input_int8, input_scale, weight_int8, weight_scale, block_size,
        bias=bias, out_quant=False
    )

    # Method 2: Call the pure-PyTorch INT8 GEMM fallback directly.
    print("Calling PyTorch fallback...")
    output_pytorch = _int8_gemm_pytorch_fallback(
        input_int8, input_scale, weight_int8, weight_scale, block_size, bias=bias
    )

    # Convert both to float32 so a dtype mismatch can't skew the comparison.
    output_triton_fp32 = output_triton.to(torch.float32)
    output_pytorch_fp32 = output_pytorch.to(torch.float32)

    # Compare Method 1 vs Method 2: Triton vs PyTorch INT8 GEMM.
    abs_diff = (output_triton_fp32 - output_pytorch_fp32).abs()
    mean_abs_diff = abs_diff.mean().item()
    max_abs_diff = abs_diff.max().item()
    passed = mean_abs_diff < threshold

    print("\nComparison:")
    print(f"  Output shape: {output_triton.shape}")
    print(f"  Triton sample [0,0,:5]: {output_triton[0, 0, :5].cpu()}")
    print(f"  PyTorch sample [0,0,:5]: {output_pytorch[0, 0, :5].cpu()}")
    print(f"  Difference [0,0,:5]: {abs_diff[0, 0, :5].cpu()}")
    print(f"\n  Mean absolute difference: {mean_abs_diff:.6f}")
    print(f"  Max absolute difference: {max_abs_diff:.6f}")
    # Threshold is the single `threshold` parameter (was a hard-coded
    # "0.001000" literal that could drift from the actual comparison).
    print(f"  Test threshold: {threshold:.6f}")
    print(f"  PASS: {passed}")

    # Show where the largest difference is: unravel the flat argmax index
    # into (batch, seq, feature) coordinates via divmod.
    max_idx_flat = abs_diff.argmax().item()
    shape = abs_diff.shape
    idx_0, rem = divmod(max_idx_flat, shape[1] * shape[2])
    idx_1, idx_2 = divmod(rem, shape[2])

    print(f"\nMax difference location: [{idx_0}, {idx_1}, {idx_2}]")
    print(f"  Triton value: {output_triton_fp32[idx_0, idx_1, idx_2].item():.6f}")
    print(f"  PyTorch value: {output_pytorch_fp32[idx_0, idx_1, idx_2].item():.6f}")

    # Check if there's a pattern by N-block: error clustering in specific
    # output blocks would point at a per-block scale-indexing bug.
    print("\nDifference by N-block:")
    for n_block in range(out_features // block_size):
        start = n_block * block_size
        end = start + block_size
        block_diff = abs_diff[:, :, start:end].mean().item()
        print(f"  N-block {n_block} (outputs {start}:{end}): mean diff = {block_diff:.6f}")

    print("=" * 80)
    return passed
244+
if __name__ == "__main__":
    # Every test in this module needs a CUDA device, so bail out early
    # when none is present.
    if not torch.cuda.is_available():
        print("CUDA not available, skipping tests")
    else:
        print("CUDA available, running tests...\n")
        suite = (test_simple_case, test_scale_loading, test_exact_failing_case)
        for position, run_test in enumerate(suite):
            if position:
                print("\n")  # blank gap between consecutive tests
            run_test()
159255
0 commit comments