Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Torchao floatx version guard (#12923)
* Add torchao version guard for floatx usage

  Summary: TorchAO is removing floatx support, so a version guard was added in quantization_config.py.
  Tests in test_torchao.py were altered to version guard floatx, and a new test was created to verify
  the version guard of floatx support.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
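In user-facing terms, the guard changes what `TorchAoConfig` accepts depending on the installed torchao version. A minimal sketch of the behavior (assuming torchao is installed and, since float quant types are also gated on hardware, an XPU or a GPU with compute capability of at least 8.9):

```python
from diffusers import TorchAoConfig

try:
    # Accepted on torchao <= 0.14.1; on torchao >= 0.15.0 the new guard raises a
    # ValueError asking for a downgrade.
    config = TorchAoConfig("fp6_e3m2")
    print("fpx quantization available:", config.quant_type)
except ValueError as err:
    print("fpx quantization rejected:", err)

# float8 shorthands are unaffected by the new torchao version guard.
float8_config = TorchAoConfig("float8wo_e4m3")
```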
@@ -457,7 +457,7 @@ class TorchAoConfig(QuantizationConfigMixin):
             - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
               `float8_e4m3_tensor`, `float8_e4m3_row`,
 
-        - **Floating point X-bit quantization:**
+        - **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0)
             - Full function names: `fpx_weight_only`
             - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
               of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
@@ -531,12 +531,18 @@ class TorchAoConfig(QuantizationConfigMixin):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
 
         if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
-            is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
-            if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
+            is_floatx_quant_type = self.quant_type.startswith("fp")
+            is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
+            if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
                 raise ValueError(
                     f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
                     f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
                 )
+            elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
+                raise ValueError(
+                    f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
+                    f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
+                )
 
             raise ValueError(
                 f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
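The two flags split responsibility between the existing hardware check and the new version check: `is_float_quant_type` covers every floating-point quant type for the capability error, while `is_floatx_quant_type` isolates the `fpX` shorthands for the torchao-version error. A pure-string illustration of how the prefixes classify a few names:

```python
# Illustration only: the same prefix checks used in the guard above, applied to sample names.
for quant_type in ["float8wo_e4m3", "fp6_e3m2", "int8wo"]:
    is_floatx = quant_type.startswith("fp")                 # only the fpX shorthands
    is_float = quant_type.startswith("float") or is_floatx  # any floating-point quant type
    print(f"{quant_type}: is_floatx={is_floatx}, is_float={is_float}")

# Expected output:
# float8wo_e4m3: is_floatx=False, is_float=True
# fp6_e3m2: is_floatx=True, is_float=True
# int8wo: is_floatx=False, is_float=False
```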
@@ -622,7 +628,6 @@ class TorchAoConfig(QuantizationConfigMixin):
             float8_dynamic_activation_float8_weight,
             float8_static_activation_float8_weight,
             float8_weight_only,
-            fpx_weight_only,
             int4_weight_only,
             int8_dynamic_activation_int4_weight,
             int8_dynamic_activation_int8_weight,
@@ -630,6 +635,8 @@ class TorchAoConfig(QuantizationConfigMixin):
             uintx_weight_only,
         )
 
+        if is_torchao_version("<=", "0.14.1"):
+            from torchao.quantization import fpx_weight_only
         # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
         from torchao.quantization.observer import PerRow, PerTensor
 
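With `fpx_weight_only` dropped from the unconditional import list above, the helper is only imported when the installed torchao still ships it. A sketch of the same guarded-import pattern in isolation (assuming `is_torchao_version` is importable from `diffusers.utils`, which is what `..utils` resolves to in this module):

```python
from diffusers.utils import is_torchao_version

fpx_weight_only = None
if is_torchao_version("<=", "0.14.1"):
    # fpx_weight_only exists only up to torchao 0.14.1, so import it conditionally.
    from torchao.quantization import fpx_weight_only
```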
@@ -650,18 +657,21 @@ class TorchAoConfig(QuantizationConfigMixin):
             return types
 
         def generate_fpx_quantization_types(bits: int):
-            types = {}
+            if is_torchao_version("<=", "0.14.1"):
+                types = {}
 
-            for ebits in range(1, bits):
-                mbits = bits - ebits - 1
-                types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
+                for ebits in range(1, bits):
+                    mbits = bits - ebits - 1
+                    types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
 
-            non_sign_bits = bits - 1
-            default_ebits = (non_sign_bits + 1) // 2
-            default_mbits = non_sign_bits - default_ebits
-            types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
+                non_sign_bits = bits - 1
+                default_ebits = (non_sign_bits + 1) // 2
+                default_mbits = non_sign_bits - default_ebits
+                types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
 
-            return types
+                return types
+            else:
+                raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
 
         INT4_QUANTIZATION_TYPES = {
             # int4 weight + bfloat16/float16 activation
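For torchao <= 0.14.1 the nested helper still enumerates every valid exponent/mantissa split for a given bit width. The naming logic, shown standalone without the torchao `partial()` binding (hypothetical helper name, for illustration only):

```python
# Hypothetical helper mirroring generate_fpx_quantization_types, kept free of torchao imports.
def fpx_shorthands(bits: int) -> list[str]:
    names = []
    for ebits in range(1, bits):
        mbits = bits - ebits - 1  # X == ebits + mbits + 1 (one sign bit)
        names.append(f"fp{bits}_e{ebits}m{mbits}")
    non_sign_bits = bits - 1
    default_ebits = (non_sign_bits + 1) // 2
    default_mbits = non_sign_bits - default_ebits
    names.append(f"fp{bits}")  # bare shorthand uses the default split, e.g. fp6 -> e3m2
    return names

print(fpx_shorthands(6))
# ['fp6_e1m4', 'fp6_e2m3', 'fp6_e3m2', 'fp6_e4m1', 'fp6_e5m0', 'fp6']
```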
@@ -710,15 +720,15 @@ class TorchAoConfig(QuantizationConfigMixin):
             **generate_float8dq_types(torch.float8_e4m3fn),
             # float8 weight + float8 activation (static)
             "float8_static_activation_float8_weight": float8_static_activation_float8_weight,
-            # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
-            # fpx weight + bfloat16/float16 activation
-            **generate_fpx_quantization_types(3),
-            **generate_fpx_quantization_types(4),
-            **generate_fpx_quantization_types(5),
-            **generate_fpx_quantization_types(6),
-            **generate_fpx_quantization_types(7),
         }
 
+        if is_torchao_version("<=", "0.14.1"):
+            FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
+            FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
+            FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
+            FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
+            FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
+
         UINTX_QUANTIZATION_DTYPES = {
             "uintx_weight_only": uintx_weight_only,
             "uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
@@ -256,9 +256,12 @@ class TorchAoTest(unittest.TestCase):
             # Cutlass fails to initialize for below
             # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
             # =====
-            ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
-            ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
         ])
+        if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+            QUANTIZATION_TYPES_TO_TEST.extend([
+                ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+                ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ])
         # fmt: on
 
         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
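The tests use a plain version comparison rather than the `is_torchao_version` helper. The same check in isolation (standard-library `importlib.metadata` plus the `packaging` package; assumes torchao is installed so the metadata lookup does not raise):

```python
import importlib.metadata

from packaging import version

torchao_version = version.parse(importlib.metadata.version("torchao"))
if torchao_version <= version.Version("0.14.1"):
    print("floatx (fpX) quant types are still available in this torchao build")
else:
    print("floatx (fpX) quant types were removed; TorchAoConfig will raise for them")
```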
@@ -271,6 +274,34 @@ class TorchAoTest(unittest.TestCase):
             )
             self._test_quant_type(quantization_config, expected_slice, model_id)
 
+    @unittest.skip("Skipping floatx quantization tests")
+    def test_floatx_quantization(self):
+        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
+            if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
+                if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+                    quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+                    self._test_quant_type(
+                        quantization_config,
+                        np.array(
+                            [
+                                0.4648,
+                                0.5195,
+                                0.5547,
+                                0.4180,
+                                0.4434,
+                                0.6445,
+                                0.4316,
+                                0.4531,
+                                0.5625,
+                            ]
+                        ),
+                        model_id,
+                    )
+                else:
+                    # Make sure the correct error is thrown
+                    with self.assertRaisesRegex(ValueError, "Please downgrade"):
+                        quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+
     def test_int4wo_quant_bfloat16_conversion(self):
         """
         Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
@@ -794,8 +825,11 @@ class SlowTorchAoTests(unittest.TestCase):
         if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
-                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
             ])
+        if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+            QUANTIZATION_TYPES_TO_TEST.extend([
+                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
+            ])
         # fmt: on
 
         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: