Commit d22708f

add block-wise scaled int8 quantization based on QuantizedLayout mechanism

- add more tests comparing against a manual torch implementation
- add perf benchmarks
- fix errors caused by merging
1 parent 0c18842 commit d22708f

File tree: 12 files changed (+4697, -36 lines)


QUANTIZATION.md

Lines changed: 10 additions & 4 deletions
@@ -124,24 +124,30 @@ We define 4 possible scaling parameters that should cover most recipes in the ne
 | Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
 |--------|---------------|--------------|----------------|-----------------|-------------|
 | float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
+| int8_blockwise | int8 | float32 (per-block) | - | - | - |
+
+For int8_blockwise with block_size=128 and weight shape (N, K):
+- weight_scale shape: (N//128, K//128)
 
 You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
 
 ### Quantization Metadata
 
 The metadata stored alongside the checkpoint contains:
 - **format_version**: String to define a version of the standard
-- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
+- **layers**: A dictionary mapping layer names to their quantization configuration. Each layer's config is a dictionary with:
+  - **format**: Quantization format string that maps to the definitions found in `QUANT_ALGOS`
+  - **group_size** (optional): Block size for block-wise quantization schemes (e.g., int8_blockwise)
 
 Example:
 ```json
 {
   "_quantization_metadata": {
     "format_version": "1.0",
     "layers": {
-      "model.layers.0.mlp.up_proj": "float8_e4m3fn",
-      "model.layers.0.mlp.down_proj": "float8_e4m3fn",
-      "model.layers.1.mlp.up_proj": "float8_e4m3fn"
+      "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
+      "model.layers.0.mlp.down_proj": {"format": "int8_blockwise", "group_size": 128},
+      "model.layers.1.mlp.up_proj": {"format": "int8_blockwise", "group_size": 256}
     }
   }
 }
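The per-layer metadata shown in the diff above is straightforward to consume. The helper below is a hypothetical sketch (it is not part of this commit); it only assumes the documented key names `_quantization_metadata`, `layers`, `format`, and `group_size`:

```python
import json

def layer_quant_config(metadata_json: str, layer: str):
    """Return (format, group_size) for a layer, or None if the layer
    is not quantized. group_size is None for non-block-wise formats.
    Illustrative helper only -- not ComfyUI code."""
    meta = json.loads(metadata_json)["_quantization_metadata"]
    cfg = meta["layers"].get(layer)
    if cfg is None:
        return None  # layer absent from metadata -> unquantized
    return cfg["format"], cfg.get("group_size")
```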

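The int8_blockwise layout documented above (one float32 scale per block_size x block_size tile, giving a weight_scale of shape (N//128, K//128) for block_size=128) can be sketched in plain PyTorch. This is a minimal max-abs reference, not the actual `comfy/quant_ops.py` implementation:

```python
import torch

def quantize_int8_blockwise(weight: torch.Tensor, block_size: int = 128):
    """Quantize an (N, K) weight to int8 with one float32 scale per
    (block_size x block_size) block. Assumes N and K are divisible
    by block_size. Sketch only -- not ComfyUI's kernel."""
    n, k = weight.shape
    # View as a grid of blocks: (N//bs, bs, K//bs, bs)
    blocks = weight.reshape(n // block_size, block_size,
                            k // block_size, block_size)
    # Per-block max-abs scale mapping to the symmetric int8 range [-127, 127];
    # clamp_min guards against all-zero blocks dividing by zero.
    scale = (blocks.abs().amax(dim=(1, 3), keepdim=True).float()
             / 127.0).clamp_min(1e-12)
    q = torch.clamp(torch.round(blocks / scale), -127, 127).to(torch.int8)
    # weight_scale comes out as (N//block_size, K//block_size)
    return q.reshape(n, k), scale.squeeze(1).squeeze(-1)

def dequantize_int8_blockwise(q: torch.Tensor, scale: torch.Tensor,
                              block_size: int = 128):
    """Inverse transform: broadcast each block's scale back over its tile."""
    n, k = q.shape
    blocks = q.reshape(n // block_size, block_size,
                       k // block_size, block_size).float()
    return (blocks * scale[:, None, :, None]).reshape(n, k)
```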
comfy/float.py

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,8 @@ def stochastic_rounding(value, dtype, seed=0):
         return value.to(dtype=torch.float16)
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
+    if dtype == torch.int8:
+        return value.to(dtype=torch.int8)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
         generator = torch.Generator(device=value.device)
         generator.manual_seed(seed)

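Note that the new int8 branch casts directly rather than rounding stochastically; only the float8 branches below it use the seeded generator. For reference, stochastic rounding itself (round up with probability equal to the fractional part, so the result is unbiased in expectation) can be sketched like this. This is an illustrative standalone function, not the patch's code:

```python
import torch

def stochastic_round_to_int8(value: torch.Tensor, seed: int = 0) -> torch.Tensor:
    """Unbiased stochastic rounding to int8. Each element rounds up with
    probability equal to its fractional part, mirroring the seeded-generator
    approach the float8 branches use. Sketch only -- the actual commit
    casts int8 deterministically with value.to(torch.int8)."""
    generator = torch.Generator(device=value.device)
    generator.manual_seed(seed)
    floor = value.floor()
    frac = value - floor
    noise = torch.rand(value.shape, generator=generator, device=value.device)
    rounded = floor + (noise < frac).to(value.dtype)
    return rounded.clamp(-128, 127).to(torch.int8)
```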