mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[quantization] feat: support aobaseconfig classes in TorchAOConfig (#12275)

* feat: support aobaseconfig classes.

* [docs] AOBaseConfig (#12302)

init

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* up

* replace with is_torchao_version

* up

* up

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Sayak Paul
2025-09-29 18:04:18 +05:30
committed by GitHub
parent 0a151115bb
commit 64a5187d96
4 changed files with 297 additions and 91 deletions

View File

@@ -11,69 +11,96 @@ specific language governing permissions and limitations under the License. -->
# torchao
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training PyTorch models. It is supported for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
Before you begin, make sure you have PyTorch 2.5+ and TorchAO installed.
Make sure PyTorch 2.5+ and torchao are installed with the command below.
```bash
pip install -U torch torchao
uv pip install -U torch torchao
```
Each quantization dtype is available as a separate instance of an [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) class. This provides more flexible configuration options by exposing more of the available arguments.
Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
Pass the `AOBaseConfig` of a quantization dtype, such as [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig), to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`].
The example below only quantizes the weights to int8.
```py
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
```
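The quantized pipeline is used like any other pipeline. A minimal usage sketch follows (not part of this commit; the prompt and sampling settings are only illustrative).
```py
# run inference with the int8 weight-only quantized transformer
prompt = "A cat holding a sign that says hello world"
image = pipeline(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```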
For simple use cases, you can also provide a string identifier to [`TorchAoConfig`], as shown below.
```py
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig("int8wo")}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
```
## torch.compile
torchao supports [torch.compile](../optimization/fp16#torchcompile), which can speed up inference with one line of code.
```python
import torch
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig
model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
transformer = AutoModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=dtype,
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
)
pipe = FluxPipeline.from_pretrained(
model_id,
transformer=transformer,
torch_dtype=dtype,
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
pipe.to("cuda")
# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
```
TorchAO is fully compatible with [torch.compile](../optimization/fp16#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
```python
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```
For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
> [!TIP]
> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
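As a sketch of that recommendation (not part of this commit), the snippet below pairs an FP8 weight-only config with `torch.compile`. It assumes `Float8WeightOnlyConfig` is available in your torchao build and mirrors the `PipelineQuantizationConfig` pattern used above.
```py
# a minimal sketch, assuming Float8WeightOnlyConfig is available in your torchao version
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Float8WeightOnlyConfig())}
)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

# compile the quantized transformer for an additional speedup
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
```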
torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
## autoquant
The `TorchAoConfig` class accepts three parameters:
- `quant_type`: A string value mentioning one of the quantization types below.
- `modules_to_not_convert`: A list of full or partial module names that should not be quantized. For example, to skip quantizing the first block of [`FluxTransformer2DModel`], specify `modules_to_not_convert=["single_transformer_blocks.0"]` (see the sketch after this list).
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
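For illustration, here is a minimal sketch (not part of this commit) of how these parameters fit together, keeping the first single transformer block in its original precision; the module name is only an example match.
```py
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# assumed example: int8 weight-only quantization, skipping one block
quantization_config = TorchAoConfig(
    "int8wo",
    modules_to_not_convert=["single_transformer_blocks.0"],
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```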
torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant), an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment.
```py
import torch
from diffusers import DiffusionPipeline
from torchao.quantization import autoquant
# Load the pipeline
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
transformer = autoquant(pipeline.transformer)
```
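As a hedged follow-up (not part of this commit), the prepared transformer can be assigned back to the pipeline and inference run as usual; the prompt and step count below are only illustrative for FLUX.1-schnell.
```py
# assign the autoquant-prepared transformer back and run inference as usual
pipeline.transformer = transformer
image = pipeline(
    "A cat holding a sign that says hello world",
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("autoquant_output.png")
```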
## Supported quantization types

View File

@@ -21,19 +21,20 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
"""
import copy
import dataclasses
import importlib.metadata
import inspect
import json
import os
import warnings
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from packaging import version
from ..utils import is_torch_available, is_torchao_available, logging
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
if is_torch_available():
@@ -443,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.
Args:
quant_type (`str`):
quant_type (Union[`str`, AOBaseConfig]):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
@@ -465,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
- An AOBaseConfig instance: for more advanced configuration options.
modules_to_not_convert (`List[str]`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
@@ -478,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
```python
from diffusers import FluxTransformer2DModel, TorchAoConfig
# AOBaseConfig-based configuration
from torchao.quantization import Int8WeightOnlyConfig
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
# String-based config
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
@@ -490,7 +498,7 @@ class TorchAoConfig(QuantizationConfigMixin):
def __init__(
self,
quant_type: str,
quant_type: Union[str, "AOBaseConfig"], # noqa: F821
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
) -> None:
@@ -504,34 +512,103 @@ class TorchAoConfig(QuantizationConfigMixin):
else:
self.quant_type_kwargs = kwargs
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
self.post_init()
def post_init(self):
if not isinstance(self.quant_type, str):
if is_torchao_version("<=", "0.9.0"):
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
)
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
from torchao.quantization.quant_api import AOBaseConfig
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
if not isinstance(self.quant_type, AOBaseConfig):
raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}")
if len(unsupported_kwargs) > 0:
raise ValueError(
f'The quantization method "{quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
)
elif isinstance(self.quant_type, str):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
)
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
if len(unsupported_kwargs) > 0:
raise ValueError(
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
)
def to_dict(self):
"""Convert configuration to a dictionary."""
d = super().to_dict()
if isinstance(self.quant_type, str):
# Handle layout serialization if present
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
if is_dataclass(d["quant_type_kwargs"]["layout"]):
d["quant_type_kwargs"]["layout"] = [
d["quant_type_kwargs"]["layout"].__class__.__name__,
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
]
if isinstance(d["quant_type_kwargs"]["layout"], list):
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
else:
raise ValueError("layout must be a list")
else:
# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict
# For now we assume there is 1 config per Transformer, however in the future
# We may want to support a config per fqn.
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
return d
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
if not is_torchao_version(">", "0.9.0"):
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")
if isinstance(quant_type, str):
return cls(quant_type=quant_type, **config_dict)
# Check if we only have one key which is "default"
# In the future we may update this
assert len(quant_type) == 1 and "default" in quant_type, (
"Expected only one key 'default' in quant_type dictionary"
)
quant_type = quant_type["default"]
# Deserialize quant_type if needed
from torchao.core.config import config_from_dict
quant_type = config_from_dict(quant_type)
return cls(quant_type=quant_type, **config_dict)
@classmethod
def _get_torchao_quant_type_to_method(cls):
@@ -681,8 +758,38 @@ class TorchAoConfig(QuantizationConfigMixin):
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
def get_apply_tensor_subclass(self):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
"""Create the appropriate quantization method based on configuration."""
if not isinstance(self.quant_type, str):
return self.quant_type
else:
methods = self._get_torchao_quant_type_to_method()
quant_type_kwargs = self.quant_type_kwargs.copy()
if (
not torch.cuda.is_available()
and is_torchao_available()
and self.quant_type == "int4_weight_only"
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
and quant_type_kwargs.get("layout", None) is None
):
if torch.xpu.is_available():
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
"0.11.0"
) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
from torchao.dtypes import Int4XPULayout
from torchao.quantization.quant_primitives import ZeroPointDomain
quant_type_kwargs["layout"] = Int4XPULayout()
quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
else:
raise ValueError(
"TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
)
else:
from torchao.dtypes import Int4CPULayout
quant_type_kwargs["layout"] = Int4CPULayout()
return methods[self.quant_type](**quant_type_kwargs)
def __repr__(self):
r"""

View File

@@ -18,9 +18,10 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
"""
import importlib
import re
import types
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Dict, List, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from packaging import version
@@ -107,6 +108,21 @@ if (
_update_torch_safe_globals()
def fuzzy_match_size(config_name: str) -> Optional[str]:
"""
Extract the size digit from strings like "4weight" or "8weight". Returns the digit as a string if found, otherwise
None.
"""
config_name = config_name.lower()
str_match = re.search(r"(\d)weight", config_name)
if str_match:
return str_match.group(1)
return None
logger = logging.get_logger(__name__)
@@ -176,8 +192,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
if quant_type.startswith("int") or quant_type.startswith("uint"):
if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
@@ -197,24 +212,44 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
from accelerate.utils import CustomDtype
if quant_type.startswith("int8") or quant_type.startswith("int4"):
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
return torch.int8
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16
if isinstance(quant_type, str):
if quant_type.startswith("int8"):
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
return torch.int8
elif quant_type.startswith("int4"):
return CustomDtype.INT4
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16
elif is_torchao_version(">", "0.9.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
else:
# Default to int8
return torch.int8
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
return target_dtype
@@ -297,6 +332,21 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
# Original mapping for non-AOBaseConfig types
# For the uint types, this is a best guess. Once these types become more used
# we can look into their nuances.
if is_torchao_version(">", "0.9.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
# For the autoquant case, it is handled by the string-based map_to_target_dtype mapping below
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
if size_digit == "4":
return 8
else:
return 4
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
quant_type = self.quantization_config.quant_type
for pattern, target_dtype in map_to_target_dtype.items():

View File

@@ -14,11 +14,13 @@
# limitations under the License.
import gc
import importlib.metadata
import tempfile
import unittest
from typing import List
import numpy as np
from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -65,6 +67,9 @@ if is_torchao_available():
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"):
from torchao.quantization import Int8WeightOnlyConfig
@require_torch
@require_torch_accelerator
@@ -522,6 +527,15 @@ class TorchAoTest(unittest.TestCase):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)
@require_torchao_version_greater_or_equal("0.9.0")
def test_aobase_config(self):
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@@ -628,6 +642,14 @@ class TorchAoSerializationTest(unittest.TestCase):
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@require_torchao_version_greater_or_equal("0.9.0")
def test_aobase_config(self):
quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):