@@ -1020,7 +1020,7 @@ def int8_linear(func, args, kwargs):
 
     orig_dtype = input_tensor._layout_params['orig_dtype']
     out_dtype = kwargs.get('out_dtype', orig_dtype)
-    out_quant = kwargs.get('out_quant', True)  # Whether to return quantized output
+    out_quant = kwargs.get('out_quant', False)  # Whether to return quantized output
 
     # Weight is already in (N, K) format (standard PyTorch weight format)
     # Pass out_quant to _int8_gemm_triton_or_fallback for fused matmul+quant
@@ -1080,7 +1080,7 @@ def int8_mm(func, args, kwargs):
 
     orig_dtype = input_tensor._layout_params['orig_dtype']
     out_dtype = kwargs.get('out_dtype', orig_dtype)
-    out_quant = kwargs.get('out_quant', True)  # Whether to return quantized output (default: True)
1083+ out_quant = kwargs .get ('out_quant' , False ) # Whether to return quantized output (default: True)
 
     # Check if weight needs to be transposed to (N, K) format
     # For mm: input is (M, K), weight should be (N, K) for the kernel
@@ -1154,7 +1154,7 @@ def int8_addmm(func, args, kwargs):
 
     orig_dtype = input_tensor._layout_params['orig_dtype']
     out_dtype = kwargs.get('out_dtype', orig_dtype)
-    out_quant = kwargs.get('out_quant', True)  # Whether to return quantized output
+    out_quant = kwargs.get('out_quant', False)  # Whether to return quantized output
 
     # PyTorch's F.linear internally calls addmm(bias, input, weight.t())
     # So weight arrives in (K, N) format (transposed), need to transpose back to (N, K)
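For readers following the change: all three hunks flip the default of the `out_quant` keyword from `True` to `False` in the linear/mm/addmm overrides, so callers now receive a dequantized result in `orig_dtype` unless they explicitly request a quantized output. The sketch below is a simplified, hypothetical illustration of that dispatch pattern; the helpers `_quantize_per_tensor` and `_int8_gemm_sketch` are stand-ins for this illustration only, not the repo's actual `_int8_gemm_triton_or_fallback` kernel, and a real kernel would accumulate in int32 rather than emulating the GEMM in float32 as done here.

```python
import torch

def _quantize_per_tensor(x: torch.Tensor):
    # Hypothetical helper: symmetric per-tensor int8 quantization.
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

def _int8_gemm_sketch(a_q, a_scale, b_q, b_scale):
    # Emulated int8 GEMM: a real kernel would accumulate in int32;
    # here we cast to float32 so the sketch runs anywhere.
    acc = a_q.to(torch.float32) @ b_q.to(torch.float32).t()
    return acc * (a_scale * b_scale)

def int8_linear_sketch(input_q, input_scale, weight_q, weight_scale,
                       orig_dtype=torch.float16, **kwargs):
    # Mirrors the kwarg handling in the diff: out_quant now defaults to
    # False, so the caller gets a dequantized tensor in orig_dtype.
    out_dtype = kwargs.get('out_dtype', orig_dtype)
    out_quant = kwargs.get('out_quant', False)

    out = _int8_gemm_sketch(input_q, input_scale, weight_q, weight_scale)
    if out_quant:
        # Opt-in fused path: re-quantize the GEMM output, return int8 + scale.
        return _quantize_per_tensor(out)
    return out.to(out_dtype)

# Usage: weight stays in (N, K) layout, matching the comments in the diff.
x_q, x_s = _quantize_per_tensor(torch.randn(4, 16))
w_q, w_s = _quantize_per_tensor(torch.randn(8, 16))   # (N, K)
y = int8_linear_sketch(x_q, x_s, w_q, w_s)            # fp16 output, shape (4, 8)
y_q, y_s = int8_linear_sketch(x_q, x_s, w_q, w_s, out_quant=True)
```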