From cd991d1e1a648cffe894405db02f34059d86809f Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 25 Dec 2024 15:37:49 +0530 Subject: [PATCH] Fix TorchAO related bugs; revert device_map changes (#10371) * Revert "Add support for sharded models when TorchAO quantization is enabled (#10256)" This reverts commit 41ba8c0bf6b3dc3ebd0fa6b96ecf671fa4171566. * update tests * udpate * update * update * update device map tests * apply review suggestions * update * make style * fix * update docs * update tests * update workflow * update * improve tests * allclose tolerance * Update src/diffusers/models/modeling_utils.py Co-authored-by: Sayak Paul * Update tests/quantization/torchao/test_torchao.py Co-authored-by: Sayak Paul * improve tests * fix * update correct slices --------- Co-authored-by: Sayak Paul --- .github/workflows/nightly_tests.yml | 2 + docs/source/en/quantization/torchao.md | 62 +++ src/diffusers/models/modeling_utils.py | 8 +- .../quantizers/torchao/torchao_quantizer.py | 2 +- tests/quantization/torchao/test_torchao.py | 379 +++++++++++++----- 5 files changed, 339 insertions(+), 114 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index cc0abac6e4..9375f760a1 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -359,6 +359,8 @@ jobs: test_location: "bnb" - backend: "gguf" test_location: "gguf" + - backend: "torchao" + test_location: "torchao" runs-on: group: aws-g6e-xlarge-plus container: diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 1f9f99a79a..c056876c2f 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] The example below only quantizes the weights to int8. ```python +import torch from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig model_id = "black-forest-labs/FLUX.1-dev" @@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained( ) pipe.to("cuda") +# Without quantization: ~31.447 GB +# With quantization: ~20.40 GB +print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB") + prompt = "A cat holding a sign that says hello world" image = pipe( prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 @@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available. +## Serializing and Deserializing quantized models + +To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method. + +```python +import torch +from diffusers import FluxTransformer2DModel, TorchAoConfig + +quantization_config = TorchAoConfig("int8wo") +transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/Flux.1-Dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, +) +transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False) +``` + +To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method. + +```python +import torch +from diffusers import FluxPipeline, FluxTransformer2DModel + +transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False) +pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "A cat holding a sign that says hello world" +image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0] +image.save("output.png") +``` + +Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. + +```python +import torch +from accelerate import init_empty_weights +from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig + +# Serialize the model +transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/Flux.1-Dev", + subfolder="transformer", + quantization_config=TorchAoConfig("uint4wo"), + torch_dtype=torch.bfloat16, +) +transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB") +# ... + +# Load the model +state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu") +with init_empty_weights(): + transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json") +transformer.load_state_dict(state_dict, strict=True, assign=True) +``` + ## Resources - [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d236ebb839..d6efcc7364 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -718,10 +718,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): hf_quantizer = None if hf_quantizer is not None: - is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes" - if is_bnb_quantization_method and device_map is not None: + if device_map is not None: raise NotImplementedError( - "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future." + "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future." ) hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) @@ -820,7 +819,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): revision=revision, subfolder=subfolder or "", ) - if hf_quantizer is not None and is_bnb_quantization_method: + # TODO: https://github.com/huggingface/diffusers/issues/10013 + if hf_quantizer is not None: model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") is_sharded = False diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 5770e32c90..a829234afd 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -132,7 +132,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer): def update_torch_dtype(self, torch_dtype): quant_type = self.quantization_config.quant_type - if quant_type.startswith("int"): + if quant_type.startswith("int") or quant_type.startswith("uint"): if torch_dtype is not None and torch_dtype != torch.bfloat16: logger.warning( f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but " diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 0fa9182a33..3c3f13db9b 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -131,8 +131,9 @@ class TorchAoTest(unittest.TestCase): gc.collect() torch.cuda.empty_cache() - def get_dummy_components(self, quantization_config: TorchAoConfig): - model_id = "hf-internal-testing/tiny-flux-pipe" + def get_dummy_components( + self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe" + ): transformer = FluxTransformer2DModel.from_pretrained( model_id, subfolder="transformer", @@ -211,8 +212,8 @@ class TorchAoTest(unittest.TestCase): "timestep": timestep, } - def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]): - components = self.get_dummy_components(quantization_config) + def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str): + components = self.get_dummy_components(quantization_config, model_id) pipe = FluxPipeline(**components) pipe.to(device=torch_device) @@ -223,44 +224,45 @@ class TorchAoTest(unittest.TestCase): self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_quantization(self): - # fmt: off - QUANTIZATION_TYPES_TO_TEST = [ - ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), - ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), - ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), - ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ] + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + # fmt: off + QUANTIZATION_TYPES_TO_TEST = [ + ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), + ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), + ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), + ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ] - if TorchAoConfig._is_cuda_capability_atleast_8_9(): - QUANTIZATION_TYPES_TO_TEST.extend([ - ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), - ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), - # ===== - # The following lead to an internal torch error: - # RuntimeError: mat2 shape (32x4 must be divisible by 16 - # Skip these for now; TODO(aryan): investigate later - # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ===== - # Cutlass fails to initialize for below - # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ===== - ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), - ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), - ]) - # fmt: on + if TorchAoConfig._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), + ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), + # ===== + # The following lead to an internal torch error: + # RuntimeError: mat2 shape (32x4 must be divisible by 16 + # Skip these for now; TODO(aryan): investigate later + # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + # Cutlass fails to initialize for below + # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ]) + # fmt: on - for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: - quant_kwargs = {} - if quantization_name in ["uint4wo", "uint7wo"]: - # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here - quant_kwargs.update({"group_size": 16}) - quantization_config = TorchAoConfig( - quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs - ) - self._test_quant_type(quantization_config, expected_slice) + for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quant_kwargs = {} + if quantization_name in ["uint4wo", "uint7wo"]: + # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here + quant_kwargs.update({"group_size": 16}) + quantization_config = TorchAoConfig( + quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs + ) + self._test_quant_type(quantization_config, expected_slice, model_id) def test_int4wo_quant_bfloat16_conversion(self): """ @@ -280,12 +282,14 @@ class TorchAoTest(unittest.TestCase): self.assertEqual(weight.quant_max, 15) def test_device_map(self): + # Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did + # it would have errored out. Now, we do. So, device_map basically never worked with or without + # sharded checkpoints. This will need to be supported in the future (TODO(aryan)) """ Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps. The custom device map performs cpu/disk offloading as well. Also verifies that the device map is correctly set (in the `hf_device_map` attribute of the model). """ - custom_device_map_dict = { "time_text_embed": torch_device, "context_embedder": torch_device, @@ -297,48 +301,54 @@ class TorchAoTest(unittest.TestCase): } device_maps = ["auto", custom_device_map_dict] - inputs = self.get_dummy_tensor_inputs(torch_device) - expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) + # inputs = self.get_dummy_tensor_inputs(torch_device) + # expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) for device_map in device_maps: - device_map_to_compare = {"": 0} if device_map == "auto" else device_map + # device_map_to_compare = {"": 0} if device_map == "auto" else device_map - # Test non-sharded model - with tempfile.TemporaryDirectory() as offload_folder: - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) - quantized_model = FluxTransformer2DModel.from_pretrained( - "hf-internal-testing/tiny-flux-pipe", - subfolder="transformer", - quantization_config=quantization_config, - device_map=device_map, - torch_dtype=torch.bfloat16, - offload_folder=offload_folder, - ) + # Test non-sharded model - should work + with self.assertRaises(NotImplementedError): + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + _ = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + offload_folder=offload_folder, + ) - self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + # weight = quantized_model.transformer_blocks[0].ff.net[2].weight + # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + # self.assertTrue(isinstance(weight, AffineQuantizedTensor)) - output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + # output = quantized_model(**inputs)[0] + # output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) - # Test sharded model - with tempfile.TemporaryDirectory() as offload_folder: - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) - quantized_model = FluxTransformer2DModel.from_pretrained( - "hf-internal-testing/tiny-flux-sharded", - subfolder="transformer", - quantization_config=quantization_config, - device_map=device_map, - torch_dtype=torch.bfloat16, - offload_folder=offload_folder, - ) + # Test sharded model - should not work + with self.assertRaises(NotImplementedError): + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + _ = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-sharded", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + offload_folder=offload_folder, + ) - self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + # weight = quantized_model.transformer_blocks[0].ff.net[2].weight + # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + # self.assertTrue(isinstance(weight, AffineQuantizedTensor)) - output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + # output = quantized_model(**inputs)[0] + # output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_modules_to_not_convert(self): quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) @@ -404,43 +414,63 @@ class TorchAoTest(unittest.TestCase): @nightly def test_torch_compile(self): r"""Test that verifies if torch.compile works with torchao quantization.""" - quantization_config = TorchAoConfig("int8_weight_only") - components = self.get_dummy_components(quantization_config) - pipe = FluxPipeline(**components) - pipe.to(device=torch_device, dtype=torch.bfloat16) + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + quantization_config = TorchAoConfig("int8_weight_only") + components = self.get_dummy_components(quantization_config, model_id=model_id) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device) - inputs = self.get_dummy_inputs(torch_device) - normal_output = pipe(**inputs)[0].flatten()[-32:] + inputs = self.get_dummy_inputs(torch_device) + normal_output = pipe(**inputs)[0].flatten()[-32:] - pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) - inputs = self.get_dummy_inputs(torch_device) - compile_output = pipe(**inputs)[0].flatten()[-32:] + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) + inputs = self.get_dummy_inputs(torch_device) + compile_output = pipe(**inputs)[0].flatten()[-32:] - # Note: Seems to require higher tolerance - self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) + # Note: Seems to require higher tolerance + self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) def test_memory_footprint(self): r""" A simple test to check if the model conversion has been done correctly by checking on the memory footprint of the converted model and the class type of the linear layers of the converted models """ - transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"] - transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"] - transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"] - transformer_bf16 = self.get_dummy_components(None)["transformer"] + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"] + transformer_int4wo_gs32 = self.get_dummy_components( + TorchAoConfig("int4wo", group_size=32), model_id=model_id + )["transformer"] + transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"] + transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] - total_int4wo = get_model_size_in_bytes(transformer_int4wo) - total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32) - total_int8wo = get_model_size_in_bytes(transformer_int8wo) - total_bf16 = get_model_size_in_bytes(transformer_bf16) + # Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64 + for block in transformer_int4wo.transformer_blocks: + self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor)) - # Latter has smaller group size, so more groups -> more scales and zero points - self.assertTrue(total_int4wo < total_int4wo_gs32) - # int8 quantizes more layers compare to int4 with default group size - self.assertTrue(total_int8wo < total_int4wo) - # int4wo does not quantize too many layers because of default group size, but for the layers it does - # there is additional overhead of scales and zero points - self.assertTrue(total_bf16 < total_int4wo) + # Will quantize all the linear layers except x_embedder + for name, module in transformer_int4wo_gs32.named_modules(): + if isinstance(module, nn.Linear) and name not in ["x_embedder"]: + self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + + # Will quantize all the linear layers + for module in transformer_int8wo.modules(): + if isinstance(module, nn.Linear): + self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + + total_int4wo = get_model_size_in_bytes(transformer_int4wo) + total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32) + total_int8wo = get_model_size_in_bytes(transformer_int8wo) + total_bf16 = get_model_size_in_bytes(transformer_bf16) + + # TODO: refactor to align with other quantization tests + # Latter has smaller group size, so more groups -> more scales and zero points + self.assertTrue(total_int4wo < total_int4wo_gs32) + # int8 quantizes more layers compare to int4 with default group size + self.assertTrue(total_int8wo < total_int4wo) + # int4wo does not quantize too many layers because of default group size, but for the layers it does + # there is additional overhead of scales and zero points + self.assertTrue(total_bf16 < total_int4wo) def test_wrong_config(self): with self.assertRaises(ValueError): @@ -500,6 +530,8 @@ class TorchAoSerializationTest(unittest.TestCase): inputs = self.get_dummy_tensor_inputs(torch_device) output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + weight = quantized_model.transformer_blocks[0].ff.net[2].weight + self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device): @@ -508,8 +540,8 @@ class TorchAoSerializationTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmp_dir: quantized_model.save_pretrained(tmp_dir, safe_serialization=False) loaded_quantized_model = FluxTransformer2DModel.from_pretrained( - tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False - ) + tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False + ).to(device=torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) output = loaded_quantized_model(**inputs)[0] @@ -563,20 +595,25 @@ class SlowTorchAoTests(unittest.TestCase): torch.cuda.empty_cache() def get_dummy_components(self, quantization_config: TorchAoConfig): + # This is just for convenience, so that we can modify it at one place for custom environments and locally testing + cache_dir = None model_id = "black-forest-labs/FLUX.1-dev" transformer = FluxTransformer2DModel.from_pretrained( model_id, subfolder="transformer", quantization_config=quantization_config, torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + ) + text_encoder = CLIPTextModel.from_pretrained( + model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir ) - text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16) text_encoder_2 = T5EncoderModel.from_pretrained( - model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16 + model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir ) - tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") - tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") - vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16) + tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir) + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir) + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir) scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -611,10 +648,12 @@ class SlowTorchAoTests(unittest.TestCase): pipe = FluxPipeline(**components) pipe.enable_model_cpu_offload() + weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight + self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) + inputs = self.get_dummy_inputs(torch_device) output = pipe(**inputs)[0].flatten() output_slice = np.concatenate((output[:16], output[-16:])) - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_quantization(self): @@ -627,7 +666,7 @@ class SlowTorchAoTests(unittest.TestCase): if TorchAoConfig._is_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES_TO_TEST.extend([ ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), - ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])), + ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])), ]) # fmt: on @@ -637,3 +676,125 @@ class SlowTorchAoTests(unittest.TestCase): gc.collect() torch.cuda.empty_cache() torch.cuda.synchronize() + + def test_serialization_int8wo(self): + quantization_config = TorchAoConfig("int8wo") + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.enable_model_cpu_offload() + + weight = pipe.transformer.x_embedder.weight + self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0].flatten()[:128] + + with tempfile.TemporaryDirectory() as tmp_dir: + pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False) + pipe.remove_all_hooks() + del pipe.transformer + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + transformer = FluxTransformer2DModel.from_pretrained( + tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False + ) + pipe.transformer = transformer + pipe.enable_model_cpu_offload() + + weight = transformer.x_embedder.weight + self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + loaded_output = pipe(**inputs)[0].flatten()[:128] + # Seems to require higher tolerance depending on which machine it is being run. + # A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of + # 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04, + # on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here. + self.assertTrue(np.allclose(output, loaded_output, atol=0.06)) + + def test_memory_footprint_int4wo(self): + # The original checkpoints are in bf16 and about 24 GB + expected_memory_in_gb = 6.0 + quantization_config = TorchAoConfig("int4wo") + cache_dir = None + transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + ) + int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3 + self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb) + + def test_memory_footprint_int8wo(self): + # The original checkpoints are in bf16 and about 24 GB + expected_memory_in_gb = 12.0 + quantization_config = TorchAoConfig("int8wo") + cache_dir = None + transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + ) + int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3 + self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb) + + +@require_torch +@require_torch_gpu +@require_torchao_version_greater_or_equal("0.7.0") +@slow +@nightly +class SlowTorchAoPreserializedModelTests(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 512, + "width": 512, + "num_inference_steps": 20, + "output_type": "np", + "generator": generator, + } + + return inputs + + def test_transformer_int8wo(self): + # fmt: off + expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703]) + # fmt: on + + # This is just for convenience, so that we can modify it at one place for custom environments and locally testing + cache_dir = None + transformer = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer", + torch_dtype=torch.bfloat16, + use_safetensors=False, + cache_dir=cache_dir, + ) + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir + ) + pipe.enable_model_cpu_offload() + + # Verify that all linear layer weights are quantized + for name, module in pipe.transformer.named_modules(): + if isinstance(module, nn.Linear): + self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + + # Verify outputs match expected slice + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0].flatten() + output_slice = np.concatenate((output[:16], output[-16:])) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))