Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class AutoRound:
the quantization of LLMs." arXiv:2309.05516 (2023).

Attributes:
model (torch.nn.Module): The loaded PyTorch model in eval mode.
model (torch.nn.Module | str): The loaded PyTorch model in eval mode.
tokenizer: Tokenizer used to prepare input text for calibration/tuning.
platform (str): The platform to load pretrained moded, options: ["hf", "model_scope"]
bits (int): Weight quantization bits.
Expand Down Expand Up @@ -85,6 +85,8 @@ def __new__(
enable_adam: bool = False,
# for MLLM and Diffusion
extra_config: ExtraConfig = None,
enable_alg_ext: bool = False,
disable_opt_rtn: bool = False,
low_cpu_mem_usage: bool = False,
**kwargs,
) -> BaseCompressor:
Expand Down
12 changes: 10 additions & 2 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,13 @@ def __init__(
if self.iters == 0:
self.lr = 5e-3
else:
self.lr = lr or (1.0 / self.iters) # must place after iter setting
if not lr:
# TODO need to check 3/4 bits lr setting for auto-round-best
self.lr = 2.0 / self.iters if (self.iters >= 1000 and self.bits == 2) else 1.0 / self.iters
if self.iters >= 1000 and self.bits == 2:
logger.info("set the lr to 2.0/iters for better accuracy")
else:
self.lr = lr
self.minmax_lr = minmax_lr or self.lr
self.enable_alg_ext = enable_alg_ext
self.sampler = sampler
Expand Down Expand Up @@ -510,7 +516,9 @@ def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None:
else:
raise TypeError(f"device_map should be [str, torch.device, int, dict], but got {type(device_map)}")

def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kwargs) -> QuantizationScheme:
def _parse_and_set_scheme(
self, scheme: Union[str, dict, QuantizationScheme], kwargs
) -> tuple[QuantizationScheme, bool]:
"""Parse and set the quantization scheme."""

def _parse_and_set(scheme, kwargs):
Expand Down
18 changes: 10 additions & 8 deletions docs/alg_202508.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ We use **lm-eval** for evaluation. For LLaMA, we enabled `add_bos_token` and
in [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L52C1-L52C40)
to stabilize accuracy during evaluation. All other settings follow the default configurations of AutoRound and lm-eval.

| Qwen3-8B W2G64 | Avg. | arc_challenge | hellaswag | gsm8k | lambada_openai | mmlu | mmlupro | truthfulqa_mc1 | winogrande |
|:-------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
| AutoRound | 0.4373 | 0.4019 | 0.4437 | 0.4215 | 0.4826 | 0.5474 | 0.2630 | 0.3072 | 0.6314 |
| AutoRound+alg_ext | 0.4787 | 0.4275 | 0.4516 | 0.5944 | 0.5181 | 0.5773 | 0.2807 | 0.3305 | 0.6496 |
| Qwen3-8B W2G64 | Avg. | arc_challenge | hellaswag | gsm8k | lambada_openai | mmlu | mmlupro | truthfulqa_mc1 | winogrande |
|:------------------------------|:------:|:-------------:|:---------:|:------:|:--------------:|:------:|:-------:|:--------------:|:----------:|
| AutoRound | 0.4373 | 0.4019 | 0.4437 | 0.4215 | 0.4826 | 0.5474 | 0.2630 | 0.3072 | 0.6314 |
| AutoRound+alg_ext | 0.4787 | 0.4275 | 0.4516 | 0.5944 | 0.5181 | 0.5773 | 0.2807 | 0.3305 | 0.6496 |
| AutoRoundBest+alg_ext lr 2e-3 | 0.4937 | 0.4505 | 0.474 | 0.5906 | 0.5556 | 0.6028 | 0.3127 | 0.3109 | 0.6527 |

| Llama3.1-8B-Instruct W2G64 | Avg. | arc_challenge | hellaswag | gsm8k | lambada_openai | mmlu | mmlupro | truthfulqa_mc1 | winogrande |
|:---------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
| AutoRound | 0.3820 | 0.3635 | 0.4562 | 0.1622 | 0.5069 | 0.4411 | 0.1661 | 0.3207 | 0.6393 |
| AutoRound+alg_ext | 0.4166 | 0.3712 | 0.4729 | 0.2039 | 0.5946 | 0.4981 | 0.2163 | 0.3011 | 0.6748 |
| Llama3.1-8B-Instruct W2G64 | Avg. | arc_challenge | hellaswag | gsm8k | lambada_openai | mmlu | mmlupro | truthfulqa_mc1 | winogrande |
|:------------------------------|:------:|:-------------:|:---------:|:------:|:--------------:|:------:|:-------:|:--------------:|:----------:|
| AutoRound | 0.3820 | 0.3635 | 0.4562 | 0.1622 | 0.5069 | 0.4411 | 0.1661 | 0.3207 | 0.6393 |
| AutoRound+alg_ext | 0.4166 | 0.3712 | 0.4729 | 0.2039 | 0.5946 | 0.4981 | 0.2163 | 0.3011 | 0.6748 |
| AutoRoundBest+alg_ext lr 2e-3 | 0.4539 | 0.4138 | 0.4999 | 0.3071 | 0.6233 | 0.5279 | 0.2364 | 0.3231 | 0.6993 |
Loading