1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Torchao floatx version guard (#12923)

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Howard Zhang
2026-01-08 21:21:53 -08:00
committed by GitHub
parent be38f41f9f
commit 2f66edc880
2 changed files with 67 additions and 23 deletions

View File

@@ -457,7 +457,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
`float8_e4m3_tensor`, `float8_e4m3_row`,
- **Floating point X-bit quantization:**
- **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0)
- Full function names: `fpx_weight_only`
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
@@ -531,12 +531,18 @@ class TorchAoConfig(QuantizationConfigMixin):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
is_floatx_quant_type = self.quant_type.startswith("fp")
is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
)
elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
raise ValueError(
f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
)
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
@@ -622,7 +628,6 @@ class TorchAoConfig(QuantizationConfigMixin):
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
@@ -630,6 +635,8 @@ class TorchAoConfig(QuantizationConfigMixin):
uintx_weight_only,
)
if is_torchao_version("<=", "0.14.1"):
from torchao.quantization import fpx_weight_only
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
from torchao.quantization.observer import PerRow, PerTensor
@@ -650,18 +657,21 @@ class TorchAoConfig(QuantizationConfigMixin):
return types
def generate_fpx_quantization_types(bits: int):
types = {}
if is_torchao_version("<=", "0.14.1"):
types = {}
for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
return types
return types
else:
raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
INT4_QUANTIZATION_TYPES = {
# int4 weight + bfloat16/float16 activation
@@ -710,15 +720,15 @@ class TorchAoConfig(QuantizationConfigMixin):
**generate_float8dq_types(torch.float8_e4m3fn),
# float8 weight + float8 activation (static)
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
# For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
# fpx weight + bfloat16/float16 activation
**generate_fpx_quantization_types(3),
**generate_fpx_quantization_types(4),
**generate_fpx_quantization_types(5),
**generate_fpx_quantization_types(6),
**generate_fpx_quantization_types(7),
}
if is_torchao_version("<=", "0.14.1"):
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
UINTX_QUANTIZATION_DTYPES = {
"uintx_weight_only": uintx_weight_only,
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),

View File

@@ -256,9 +256,12 @@ class TorchAoTest(unittest.TestCase):
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
QUANTIZATION_TYPES_TO_TEST.extend([
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
@@ -271,6 +274,34 @@ class TorchAoTest(unittest.TestCase):
)
self._test_quant_type(quantization_config, expected_slice, model_id)
@unittest.skip("Skipping floatx quantization tests")
def test_floatx_quantization(self):
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
self._test_quant_type(
quantization_config,
np.array(
[
0.4648,
0.5195,
0.5547,
0.4180,
0.4434,
0.6445,
0.4316,
0.4531,
0.5625,
]
),
model_id,
)
else:
# Make sure the correct error is thrown
with self.assertRaisesRegex(ValueError, "Please downgrade"):
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
def test_int4wo_quant_bfloat16_conversion(self):
"""
Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
@@ -794,8 +825,11 @@ class SlowTorchAoTests(unittest.TestCase):
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
])
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
QUANTIZATION_TYPES_TO_TEST.extend([
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: