
Fix TorchAO related bugs; revert device_map changes (#10371)

* Revert "Add support for sharded models when TorchAO quantization is enabled (#10256)"

This reverts commit 41ba8c0bf6.

* update tests

* update

* update

* update

* update device map tests

* apply review suggestions

* update

* make style

* fix

* update docs

* update tests

* update workflow

* update

* improve tests

* allclose tolerance

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/quantization/torchao/test_torchao.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* improve tests

* fix

* update correct slices

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Aryan
2024-12-25 15:37:49 +05:30
committed by GitHub
parent 825979ddc3
commit cd991d1e1a
5 changed files with 339 additions and 114 deletions


@@ -359,6 +359,8 @@ jobs:
test_location: "bnb"
- backend: "gguf"
test_location: "gguf"
- backend: "torchao"
test_location: "torchao"
runs-on:
group: aws-g6e-xlarge-plus
container:


@@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
The example below only quantizes the weights to int8.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
model_id = "black-forest-labs/FLUX.1-dev"
@@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained(
)
pipe.to("cuda")
# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
@@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options.
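As a quick illustration of the aliasing, the two spellings below configure the same int8 weight-only quantization. This is only a minimal sketch that reuses the Flux transformer from the example above:
```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# "int8wo" is the short alias for "int8_weight_only"; both configure the same quantization
quantization_config = TorchAoConfig("int8wo")
# quantization_config = TorchAoConfig("int8_weight_only")  # equivalent long form

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```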
## Serializing and Deserializing quantized models
To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
```
To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the model, even though saving works as expected. To work around this, load the state dict manually into the model. Note, however, that this requires passing `weights_only=False` to `torch.load`, so it should only be done if the weights were obtained from a trusted source.
```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
# ...
# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)
```
## Resources
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)


@@ -718,10 +718,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
hf_quantizer = None
if hf_quantizer is not None:
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
if is_bnb_quantization_method and device_map is not None:
if device_map is not None:
raise NotImplementedError(
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
"Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
)
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
@@ -820,7 +819,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
revision=revision,
subfolder=subfolder or "",
)
if hf_quantizer is not None and is_bnb_quantization_method:
# TODO: https://github.com/huggingface/diffusers/issues/10013
if hf_quantizer is not None:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False
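To illustrate the user-facing effect of these two changes, here is a minimal sketch reusing the Flux transformer checkpoint from the docs above: sharded checkpoints are merged internally whenever a quantizer is present, so quantized loading needs no extra steps, while passing any `device_map` together with a quantization config currently raises `NotImplementedError`.
```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")

# Sharded checkpoints are merged before loading when a quantizer is present,
# so this works without providing a device map.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Combining `device_map` with a quantization config is rejected for now.
try:
    FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
except NotImplementedError as err:
    print(err)
```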


@@ -132,7 +132,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
if quant_type.startswith("int"):
if quant_type.startswith("int") or quant_type.startswith("uint"):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "


@@ -131,8 +131,9 @@ class TorchAoTest(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_components(self, quantization_config: TorchAoConfig):
model_id = "hf-internal-testing/tiny-flux-pipe"
def get_dummy_components(
self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
):
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
@@ -211,8 +212,8 @@ class TorchAoTest(unittest.TestCase):
"timestep": timestep,
}
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
components = self.get_dummy_components(quantization_config)
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str):
components = self.get_dummy_components(quantization_config, model_id)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device)
@@ -223,44 +224,45 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_quantization(self):
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]
if TorchAoConfig._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
# =====
# The following lead to an internal torch error:
# RuntimeError: mat2 shape (32x4 must be divisible by 16
# Skip these for now; TODO(aryan): investigate later
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
# fmt: on
if TorchAoConfig._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
# =====
# The following lead to an internal torch error:
# RuntimeError: mat2 shape (32x4 must be divisible by 16
# Skip these for now; TODO(aryan): investigate later
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quant_kwargs = {}
if quantization_name in ["uint4wo", "uint7wo"]:
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
quant_kwargs.update({"group_size": 16})
quantization_config = TorchAoConfig(
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
)
self._test_quant_type(quantization_config, expected_slice)
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quant_kwargs = {}
if quantization_name in ["uint4wo", "uint7wo"]:
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
quant_kwargs.update({"group_size": 16})
quantization_config = TorchAoConfig(
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
)
self._test_quant_type(quantization_config, expected_slice, model_id)
def test_int4wo_quant_bfloat16_conversion(self):
"""
@@ -280,12 +282,14 @@ class TorchAoTest(unittest.TestCase):
self.assertEqual(weight.quant_max, 15)
def test_device_map(self):
# Note: We were not checking whether the weight tensors were AffineQuantizedTensor instances before. If we had,
# it would have errored out. Now we do, so `device_map` basically never worked, with or without
# sharded checkpoints. This will need to be supported in the future (TODO(aryan))
"""
Test whether the int4 weight-only quantized model works properly with "auto" and custom device maps.
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
correctly set (in the `hf_device_map` attribute of the model).
"""
custom_device_map_dict = {
"time_text_embed": torch_device,
"context_embedder": torch_device,
@@ -297,48 +301,54 @@ class TorchAoTest(unittest.TestCase):
}
device_maps = ["auto", custom_device_map_dict]
inputs = self.get_dummy_tensor_inputs(torch_device)
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
# inputs = self.get_dummy_tensor_inputs(torch_device)
# expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
for device_map in device_maps:
device_map_to_compare = {"": 0} if device_map == "auto" else device_map
# device_map_to_compare = {"": 0} if device_map == "auto" else device_map
# Test non-sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
# Test non-sharded model - should work
with self.assertRaises(NotImplementedError):
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# Test sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
# Test sharded model - should not work
with self.assertRaises(NotImplementedError):
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
@@ -404,43 +414,63 @@ class TorchAoTest(unittest.TestCase):
@nightly
def test_torch_compile(self):
r"""Test that verifies if torch.compile works with torchao quantization."""
quantization_config = TorchAoConfig("int8_weight_only")
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device, dtype=torch.bfloat16)
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
quantization_config = TorchAoConfig("int8_weight_only")
components = self.get_dummy_components(quantization_config, model_id=model_id)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device)
inputs = self.get_dummy_inputs(torch_device)
normal_output = pipe(**inputs)[0].flatten()[-32:]
inputs = self.get_dummy_inputs(torch_device)
normal_output = pipe(**inputs)[0].flatten()[-32:]
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
inputs = self.get_dummy_inputs(torch_device)
compile_output = pipe(**inputs)[0].flatten()[-32:]
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
inputs = self.get_dummy_inputs(torch_device)
compile_output = pipe(**inputs)[0].flatten()[-32:]
# Note: Seems to require higher tolerance
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
# Note: Seems to require higher tolerance
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
def test_memory_footprint(self):
r"""
A simple test to check whether the model conversion has been done correctly by checking the
memory footprint of the converted model and the class type of the linear layers of the converted model
"""
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"]
transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
transformer_bf16 = self.get_dummy_components(None)["transformer"]
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
transformer_int4wo_gs32 = self.get_dummy_components(
TorchAoConfig("int4wo", group_size=32), model_id=model_id
)["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
total_int4wo = get_model_size_in_bytes(transformer_int4wo)
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
total_int8wo = get_model_size_in_bytes(transformer_int8wo)
total_bf16 = get_model_size_in_bytes(transformer_bf16)
# Will not quantize all the layers by default due to the model weight shapes not being divisible by group_size=64
for block in transformer_int4wo.transformer_blocks:
self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))
# Latter has smaller group size, so more groups -> more scales and zero points
self.assertTrue(total_int4wo < total_int4wo_gs32)
# int8 quantizes more layers compared to int4 with the default group size
self.assertTrue(total_int8wo < total_int4wo)
# int4wo does not quantize too many layers because of default group size, but for the layers it does
# there is additional overhead of scales and zero points
self.assertTrue(total_bf16 < total_int4wo)
# Will quantize all the linear layers except x_embedder
for name, module in transformer_int4wo_gs32.named_modules():
if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
# Will quantize all the linear layers
for module in transformer_int8wo.modules():
if isinstance(module, nn.Linear):
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
total_int4wo = get_model_size_in_bytes(transformer_int4wo)
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
total_int8wo = get_model_size_in_bytes(transformer_int8wo)
total_bf16 = get_model_size_in_bytes(transformer_bf16)
# TODO: refactor to align with other quantization tests
# Latter has smaller group size, so more groups -> more scales and zero points
self.assertTrue(total_int4wo < total_int4wo_gs32)
# int8 quantizes more layers compared to int4 with the default group size
self.assertTrue(total_int8wo < total_int4wo)
# int4wo does not quantize too many layers because of default group size, but for the layers it does
# there is additional overhead of scales and zero points
self.assertTrue(total_bf16 < total_int4wo)
def test_wrong_config(self):
with self.assertRaises(ValueError):
@@ -500,6 +530,8 @@ class TorchAoSerializationTest(unittest.TestCase):
inputs = self.get_dummy_tensor_inputs(torch_device)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
@@ -508,8 +540,8 @@ class TorchAoSerializationTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False
)
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
).to(device=torch_device)
inputs = self.get_dummy_tensor_inputs(torch_device)
output = loaded_quantized_model(**inputs)[0]
@@ -563,20 +595,25 @@ class SlowTorchAoTests(unittest.TestCase):
torch.cuda.empty_cache()
def get_dummy_components(self, quantization_config: TorchAoConfig):
# This is just for convenience, so that we can modify it in one place for custom environments and local testing
cache_dir = None
model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
)
text_encoder = CLIPTextModel.from_pretrained(
model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
text_encoder_2 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
@@ -611,10 +648,12 @@ class SlowTorchAoTests(unittest.TestCase):
pipe = FluxPipeline(**components)
pipe.enable_model_cpu_offload()
weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0].flatten()
output_slice = np.concatenate((output[:16], output[-16:]))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_quantization(self):
@@ -627,7 +666,7 @@ class SlowTorchAoTests(unittest.TestCase):
if TorchAoConfig._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])),
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
])
# fmt: on
@@ -637,3 +676,125 @@ class SlowTorchAoTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def test_serialization_int8wo(self):
quantization_config = TorchAoConfig("int8wo")
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.enable_model_cpu_offload()
weight = pipe.transformer.x_embedder.weight
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0].flatten()[:128]
with tempfile.TemporaryDirectory() as tmp_dir:
pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
pipe.remove_all_hooks()
del pipe.transformer
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
transformer = FluxTransformer2DModel.from_pretrained(
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
)
pipe.transformer = transformer
pipe.enable_model_cpu_offload()
weight = transformer.x_embedder.weight
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
loaded_output = pipe(**inputs)[0].flatten()[:128]
# Seems to require a higher tolerance depending on which machine it is run on.
# A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of
# 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04,
# on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here.
self.assertTrue(np.allclose(output, loaded_output, atol=0.06))
def test_memory_footprint_int4wo(self):
# The original checkpoints are in bf16 and about 24 GB
expected_memory_in_gb = 6.0
quantization_config = TorchAoConfig("int4wo")
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
)
int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)
def test_memory_footprint_int8wo(self):
# The original checkpoints are in bf16 and about 24 GB
expected_memory_in_gb = 12.0
quantization_config = TorchAoConfig("int8wo")
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
)
int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
@require_torch
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_inputs(self, device: torch.device, seed: int = 0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator().manual_seed(seed)
inputs = {
"prompt": "an astronaut riding a horse in space",
"height": 512,
"width": 512,
"num_inference_steps": 20,
"output_type": "np",
"generator": generator,
}
return inputs
def test_transformer_int8wo(self):
# fmt: off
expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703])
# fmt: on
# This is just for convenience, so that we can modify it in one place for custom environments and local testing
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer",
torch_dtype=torch.bfloat16,
use_safetensors=False,
cache_dir=cache_dir,
)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir
)
pipe.enable_model_cpu_offload()
# Verify that all linear layer weights are quantized
for name, module in pipe.transformer.named_modules():
if isinstance(module, nn.Linear):
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
# Verify outputs match expected slice
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0].flatten()
output_slice = np.concatenate((output[:16], output[-16:]))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))