Skip to content

Commit 3b3ef9a

Browse files
authored
Quantized Ops fixes (#10715)
* offload support, bug fixes, remove mixins * add readme
1 parent 8b0b93d commit 3b3ef9a

File tree

3 files changed

+219
-25
lines changed

3 files changed

+219
-25
lines changed

QUANTIZATION.md

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# The Comfy guide to Quantization
2+
3+
4+
## How does quantization work?
5+
6+
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
7+
8+
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
9+
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
10+
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
11+
12+
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
13+
14+
```
15+
absmax = max(abs(tensor))
16+
scale = amax / max_dynamic_range_low_precision
17+
18+
# Quantization
19+
tensor_q = (tensor / scale).to(low_precision_dtype)
20+
21+
# De-Quantization
22+
tensor_dq = tensor_q.to(fp16) * scale
23+
24+
tensor_dq ~ tensor
25+
```
26+
27+
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
28+
29+
30+
## Quantization in Comfy
31+
32+
```
33+
QuantizedTensor (torch.Tensor subclass)
34+
↓ __torch_dispatch__
35+
Two-Level Registry (generic + layout handlers)
36+
37+
MixedPrecisionOps + Metadata Detection
38+
```
39+
40+
### Representation
41+
42+
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
43+
44+
A `Layout` class defines how a specific quantization format behaves:
45+
- Required parameters
46+
- Quantize method
47+
- De-Quantize method
48+
49+
```python
50+
from comfy.quant_ops import QuantizedLayout
51+
52+
class MyLayout(QuantizedLayout):
53+
@classmethod
54+
def quantize(cls, tensor, **kwargs):
55+
# Convert to quantized format
56+
qdata = ...
57+
params = {'scale': ..., 'orig_dtype': tensor.dtype}
58+
return qdata, params
59+
60+
@staticmethod
61+
def dequantize(qdata, scale, orig_dtype, **kwargs):
62+
return qdata.to(orig_dtype) * scale
63+
```
64+
65+
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
66+
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
67+
68+
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
69+
```python
70+
from comfy.quant_ops import register_layout_op
71+
72+
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
73+
def my_linear(func, args, kwargs):
74+
# Extract tensors, call optimized kernel
75+
...
76+
```
77+
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
78+
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
79+
80+
81+
### Mixed Precision
82+
83+
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
84+
85+
**Architecture:**
86+
87+
```python
88+
class MixedPrecisionOps(disable_weight_init):
89+
_layer_quant_config = {} # Maps layer names to quantization configs
90+
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
91+
```
92+
93+
**Key mechanism:**
94+
95+
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
96+
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
97+
- If the layer name **is** in `_layer_quant_config`:
98+
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
99+
- Load associated quantization parameters (scales, block_size, etc.)
100+
101+
**Why it's needed:**
102+
103+
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
104+
105+
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
106+
107+
108+
## Checkpoint Format
109+
110+
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
111+
112+
The quantized checkpoint will contain the same layers as the original checkpoint but:
113+
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
114+
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
115+
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
116+
117+
### Scaling Parameters details
118+
We define 4 possible scaling parameters that should cover most recipes in the near-future:
119+
- **weight_scale**: quantization scalers for the weights
120+
- **weight_scale_2**: global scalers in the context of double scaling
121+
- **pre_quant_scale**: scalers used for smoothing salient weights
122+
- **input_scale**: quantization scalers for the activations
123+
124+
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
125+
|--------|---------------|--------------|----------------|-----------------|-------------|
126+
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
127+
128+
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
129+
130+
### Quantization Metadata
131+
132+
The metadata stored alongside the checkpoint contains:
133+
- **format_version**: String to define a version of the standard
134+
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
135+
136+
Example:
137+
```json
138+
{
139+
"_quantization_metadata": {
140+
"format_version": "1.0",
141+
"layers": {
142+
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
143+
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
144+
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
145+
}
146+
}
147+
}
148+
```
149+
150+
151+
## Creating Quantized Checkpoints
152+
153+
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
154+
155+
### Weight Quantization
156+
157+
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
158+
159+
### Calibration (for Activation Quantization)
160+
161+
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
162+
163+
1. **Collect statistics**: Run inference on N representative samples
164+
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
165+
3. **Compute scales**: Derive `input_scale` from collected statistics
166+
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
167+
168+
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

comfy/ops.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
7777
# will add async-offload support to your cast and improve performance.
7878
if input is not None:
7979
if dtype is None:
80-
dtype = input.dtype
80+
if isinstance(input, QuantizedTensor):
81+
dtype = input._layout_params["orig_dtype"]
82+
else:
83+
dtype = input.dtype
8184
if bias_dtype is None:
8285
bias_dtype = dtype
8386
if device is None:
@@ -534,18 +537,7 @@ def forward(self, *args, **kwargs):
534537
# ==============================================================================
535538
# Mixed Precision Operations
536539
# ==============================================================================
537-
from .quant_ops import QuantizedTensor
538-
539-
QUANT_FORMAT_MIXINS = {
540-
"float8_e4m3fn": {
541-
"dtype": torch.float8_e4m3fn,
542-
"layout_type": "TensorCoreFP8Layout",
543-
"parameters": {
544-
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
545-
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
546-
}
547-
}
548-
}
540+
from .quant_ops import QuantizedTensor, QUANT_ALGOS
549541

550542
class MixedPrecisionOps(disable_weight_init):
551543
_layer_quant_config = {}
@@ -596,23 +588,24 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
596588
if quant_format is None:
597589
raise ValueError(f"Unknown quantization format for layer {layer_name}")
598590

599-
mixin = QUANT_FORMAT_MIXINS[quant_format]
600-
self.layout_type = mixin["layout_type"]
591+
qconfig = QUANT_ALGOS[quant_format]
592+
self.layout_type = qconfig["comfy_tensor_layout"]
601593

602-
scale_key = f"{prefix}weight_scale"
594+
weight_scale_key = f"{prefix}weight_scale"
603595
layout_params = {
604-
'scale': state_dict.pop(scale_key, None),
605-
'orig_dtype': MixedPrecisionOps._compute_dtype
596+
'scale': state_dict.pop(weight_scale_key, None),
597+
'orig_dtype': MixedPrecisionOps._compute_dtype,
598+
'block_size': qconfig.get("group_size", None),
606599
}
607600
if layout_params['scale'] is not None:
608-
manually_loaded_keys.append(scale_key)
601+
manually_loaded_keys.append(weight_scale_key)
609602

610603
self.weight = torch.nn.Parameter(
611-
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
604+
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
612605
requires_grad=False
613606
)
614607

615-
for param_name, param_value in mixin["parameters"].items():
608+
for param_name in qconfig["parameters"]:
616609
param_key = f"{prefix}{param_name}"
617610
_v = state_dict.pop(param_key, None)
618611
if _v is None:
@@ -643,7 +636,7 @@ def forward(self, input, *args, **kwargs):
643636
if (getattr(self, 'layout_type', None) is not None and
644637
getattr(self, 'input_scale', None) is not None and
645638
not isinstance(input, QuantizedTensor)):
646-
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
639+
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
647640
return self._forward(input, self.weight, self.bias)
648641

649642

comfy/quant_ops.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ def _copy_layout_params(params):
7474
new_params[k] = v
7575
return new_params
7676

77+
def _copy_layout_params_inplace(src, dst, non_blocking=False):
78+
for k, v in src.items():
79+
if isinstance(v, torch.Tensor):
80+
dst[k].copy_(v, non_blocking=non_blocking)
81+
else:
82+
dst[k] = v
7783

7884
class QuantizedLayout:
7985
"""
@@ -318,13 +324,13 @@ def generic_to_dtype_layout(func, args, kwargs):
318324
def generic_copy_(func, args, kwargs):
319325
qt_dest = args[0]
320326
src = args[1]
321-
327+
non_blocking = args[2] if len(args) > 2 else False
322328
if isinstance(qt_dest, QuantizedTensor):
323329
if isinstance(src, QuantizedTensor):
324330
# Copy from another quantized tensor
325-
qt_dest._qdata.copy_(src._qdata)
331+
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
326332
qt_dest._layout_type = src._layout_type
327-
qt_dest._layout_params = _copy_layout_params(src._layout_params)
333+
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
328334
else:
329335
# Copy from regular tensor - just copy raw data
330336
qt_dest._qdata.copy_(src)
@@ -336,6 +342,26 @@ def generic_copy_(func, args, kwargs):
336342
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
337343
return True
338344

345+
346+
@register_generic_util(torch.ops.aten.empty_like.default)
347+
def generic_empty_like(func, args, kwargs):
348+
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
349+
qt = args[0]
350+
if isinstance(qt, QuantizedTensor):
351+
# Create empty tensor with same shape and dtype as the quantized data
352+
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
353+
new_qdata = torch.empty_like(qt._qdata, **kwargs)
354+
355+
# Handle device transfer for layout params
356+
target_device = kwargs.get('device', new_qdata.device)
357+
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
358+
359+
# Update orig_dtype if dtype is specified
360+
new_params['orig_dtype'] = hp_dtype
361+
362+
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
363+
return func(*args, **kwargs)
364+
339365
# ==============================================================================
340366
# FP8 Layout + Operation Handlers
341367
# ==============================================================================
@@ -378,6 +404,13 @@ def dequantize(qdata, scale, orig_dtype, **kwargs):
378404
def get_plain_tensors(cls, qtensor):
379405
return qtensor._qdata, qtensor._layout_params['scale']
380406

407+
QUANT_ALGOS = {
408+
"float8_e4m3fn": {
409+
"storage_t": torch.float8_e4m3fn,
410+
"parameters": {"weight_scale", "input_scale"},
411+
"comfy_tensor_layout": "TensorCoreFP8Layout",
412+
},
413+
}
381414

382415
LAYOUTS = {
383416
"TensorCoreFP8Layout": TensorCoreFP8Layout,

0 commit comments

Comments
 (0)