
Commit 3f382a4
quant ops: Dequantize weight in-place (#10935)
In flux2 these weights are huge (200 MB). As plain_tensor is a throw-away deep copy, do this multiplication in-place to save VRAM.

1 file changed (+2, -1 lines)
comfy/quant_ops.py

@@ -425,7 +425,8 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_roun
     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
         plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
-        return plain_tensor * scale
+        plain_tensor.mul_(scale)
+        return plain_tensor

     @classmethod
     def get_plain_tensors(cls, qtensor):
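
For context on why the in-place version saves memory: _to_copy already materializes a fresh tensor at orig_dtype, so that copy can be scaled in place, whereas plain_tensor * scale allocates a second full-size tensor before the first becomes garbage. The sketch below is hypothetical, not code from this repository; it compares peak CUDA allocations of the two variants using a plain .to(dtype=...) cast (which behaves like torch.ops.aten._to_copy.default for this purpose), and assumes a CUDA device plus a PyTorch build with torch.float8_e4m3fn.

import torch

def dequant_out_of_place(qdata, scale, orig_dtype):
    plain_tensor = qdata.to(dtype=orig_dtype)   # fresh copy at orig_dtype
    return plain_tensor * scale                 # allocates a second full-size tensor

def dequant_in_place(qdata, scale, orig_dtype):
    plain_tensor = qdata.to(dtype=orig_dtype)   # fresh copy at orig_dtype
    plain_tensor.mul_(scale)                    # reuses that copy; no extra allocation
    return plain_tensor

if torch.cuda.is_available():
    # Hypothetical weight: 100M elements stored as fp8, dequantized to bf16 (~200 MB).
    q = torch.randn(100_000_000, device="cuda").to(torch.float8_e4m3fn)
    s = torch.tensor(2.0, device="cuda")
    for fn in (dequant_out_of_place, dequant_in_place):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        w = fn(q, s, torch.bfloat16)
        peak_mib = torch.cuda.max_memory_allocated() / 2**20
        print(f"{fn.__name__}: peak {peak_mib:.0f} MiB")
        del w

On such sizes the out-of-place variant peaks roughly one extra weight-sized buffer (~200 MB) higher, which is exactly what the one-line change in dequantize avoids.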
