From 5d49b3e83b97b45d8745ed6fc9f06c32d5ef9286 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 8 Apr 2025 16:47:03 +0100 Subject: [PATCH] Flux quantized with lora (#10990) * Flux quantized with lora * fix * changes * Apply suggestions from code review Co-authored-by: Sayak Paul * Apply style fixes * enable model cpu offload() * Update src/diffusers/loaders/lora_pipeline.py Co-authored-by: hlky * update * Apply suggestions from code review * update * add peft as an additional dependency for gguf --------- Co-authored-by: Sayak Paul Co-authored-by: github-actions[bot] Co-authored-by: Dhruv Nair --- .github/workflows/nightly_tests.yml | 2 +- src/diffusers/loaders/lora_pipeline.py | 63 ++++++++++++++++++++++++-- src/diffusers/quantizers/gguf/utils.py | 2 + tests/quantization/bnb/test_4bit.py | 45 +++++++++++++++++- tests/quantization/gguf/test_gguf.py | 46 +++++++++++++++++++ 5 files changed, 151 insertions(+), 7 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 2b39eea2fe..88343a128b 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -417,7 +417,7 @@ jobs: additional_deps: ["peft"] - backend: "gguf" test_location: "gguf" - additional_deps: [] + additional_deps: ["peft"] - backend: "torchao" test_location: "torchao" additional_deps: [] diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 513c15ac5d..a29b77acce 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -22,6 +22,8 @@ from ..utils import ( USE_PEFT_BACKEND, deprecate, get_submodule_by_name, + is_bitsandbytes_available, + is_gguf_available, is_peft_available, is_peft_version, is_torch_version, @@ -68,6 +70,49 @@ TRANSFORMER_NAME = "transformer" _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} +def _maybe_dequantize_weight_for_expanded_lora(model, module): + if is_bitsandbytes_available(): + from ..quantizers.bitsandbytes import dequantize_bnb_weight + + if is_gguf_available(): + from ..quantizers.gguf.utils import dequantize_gguf_tensor + + is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit" + is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter" + + if is_bnb_4bit_quantized and not is_bitsandbytes_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints." + ) + if is_gguf_quantized and not is_gguf_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints." 
+ ) + + weight_on_cpu = False + if not module.weight.is_cuda: + weight_on_cpu = True + + if is_bnb_4bit_quantized: + module_weight = dequantize_bnb_weight( + module.weight.cuda() if weight_on_cpu else module.weight, + state=module.weight.quant_state, + dtype=model.dtype, + ).data + elif is_gguf_quantized: + module_weight = dequantize_gguf_tensor( + module.weight.cuda() if weight_on_cpu else module.weight, + ) + module_weight = module_weight.to(model.dtype) + else: + module_weight = module.weight.data + + if weight_on_cpu: + module_weight = module_weight.cpu() + + return module_weight + + class StableDiffusionLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and @@ -2267,6 +2312,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): overwritten_params = {} is_peft_loaded = getattr(transformer, "peft_config", None) is not None + is_quantized = hasattr(transformer, "hf_quantizer") for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): module_weight = module.weight.data @@ -2291,9 +2337,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): if tuple(module_weight_shape) == (out_features, in_features): continue - # TODO (sayakpaul): We still need to consider if the module we're expanding is - # quantized and handle it accordingly if that is the case. - module_out_features, module_in_features = module_weight.shape + module_out_features, module_in_features = module_weight_shape debug_message = "" if in_features > module_in_features: debug_message += ( @@ -2316,6 +2360,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin): parent_module_name, _, current_module_name = name.rpartition(".") parent_module = transformer.get_submodule(parent_module_name) + if is_quantized: + module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module) + + # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True. 
with torch.device("meta"): expanded_module = torch.nn.Linear( in_features, out_features, bias=bias, dtype=module_weight.dtype @@ -2327,7 +2375,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): new_weight = torch.zeros_like( expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype ) - slices = tuple(slice(0, dim) for dim in module_weight.shape) + slices = tuple(slice(0, dim) for dim in module_weight_shape) new_weight[slices] = module_weight tmp_state_dict = {"weight": new_weight} if module_bias is not None: @@ -2416,7 +2464,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin): base_weight_param_name: str = None, ) -> "torch.Size": def _get_weight_shape(weight: torch.Tensor): - return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape + if weight.__class__.__name__ == "Params4bit": + return weight.quant_state.shape + elif weight.__class__.__name__ == "GGUFParameter": + return weight.quant_shape + else: + return weight.shape if base_module is not None: return _get_weight_shape(base_module.weight) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index effc39d8fe..de82dcab07 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -400,6 +400,8 @@ class GGUFParameter(torch.nn.Parameter): data = data if data is not None else torch.empty(0) self = torch.Tensor._make_subclass(cls, data, requires_grad) self.quant_type = quant_type + block_size, type_size = GGML_QUANT_SIZES[quant_type] + self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size) return self diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 29a3e212c4..fdcc5314d2 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -21,8 +21,15 @@ import numpy as np import pytest import safetensors.torch from huggingface_hub import hf_hub_download +from PIL import Image -from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel +from diffusers import ( + BitsAndBytesConfig, + DiffusionPipeline, + FluxControlPipeline, + FluxTransformer2DModel, + SD3Transformer2DModel, +) from diffusers.utils import is_accelerate_version, logging from diffusers.utils.testing_utils import ( CaptureLogger, @@ -696,6 +703,42 @@ class SlowBnb4BitFluxTests(Base4bitTests): self.assertTrue(max_diff < 1e-3) +@require_transformers_version_greater("4.44.0") +@require_peft_backend +class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests): + def setUp(self) -> None: + gc.collect() + torch.cuda.empty_cache() + + self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16) + self.pipeline_4bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_lora_loading(self): + self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") + + output = self.pipeline_4bit( + prompt=self.prompt, + control_image=Image.new(mode="RGB", size=(256, 256)), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + num_inference_steps=8, + generator=torch.Generator().manual_seed(42), + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3, 
msg=f"{out_slice=} != {expected_slice=}") + + @slow class BaseBnb4BitSerializationTests(Base4bitTests): def tearDown(self): diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 5e3875c7c9..e4cf1dfee1 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -8,12 +8,14 @@ import torch.nn as nn from diffusers import ( AuraFlowPipeline, AuraFlowTransformer2DModel, + FluxControlPipeline, FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig, SD3Transformer2DModel, StableDiffusion3Pipeline, ) +from diffusers.utils import load_image from diffusers.utils.testing_utils import ( is_gguf_available, nightly, @@ -21,6 +23,7 @@ from diffusers.utils.testing_utils import ( require_accelerate, require_big_gpu_with_torch_cuda, require_gguf_version_greater_or_equal, + require_peft_backend, torch_device, ) @@ -456,3 +459,46 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): ) max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) assert max_diff < 1e-4 + + +@require_peft_backend +@nightly +@require_big_gpu_with_torch_cuda +@require_accelerate +@require_gguf_version_greater_or_equal("0.10.0") +class FluxControlLoRAGGUFTests(unittest.TestCase): + def test_lora_loading(self): + ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf" + transformer = FluxTransformer2DModel.from_single_file( + ckpt_path, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16), + torch_dtype=torch.bfloat16, + ) + pipe = FluxControlPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + transformer=transformer, + torch_dtype=torch.bfloat16, + ).to("cuda") + pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") + + prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." + control_image = load_image( + "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png" + ) + + output = pipe( + prompt=prompt, + control_image=control_image, + height=256, + width=256, + num_inference_steps=10, + guidance_scale=30.0, + output_type="np", + generator=torch.manual_seed(0), + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.8047, 0.8359, 0.8711, 0.6875, 0.7070, 0.7383, 0.5469, 0.5820, 0.6641]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3)
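
Usage note (not part of the patch): a minimal sketch of the workflow this change enables, assembled from the FluxControlLoRAGGUFTests case above. The checkpoint URL, model IDs, and generation settings are copied from the test; they are intentionally small for CI and may need adjusting for real use, and hardware/VRAM assumptions are the reader's.

    # Sketch: load the Canny Control LoRA into a GGUF-quantized Flux transformer.
    # Repos, checkpoint URL, and settings are taken from the tests in this patch.
    import torch
    from diffusers import FluxControlPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
    from diffusers.utils import load_image

    ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxControlPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    # Either move to GPU as in the GGUF test, or use CPU offload as in the bnb test:
    # pipe.enable_model_cpu_offload()
    pipe.to("cuda")

    # The Control LoRA expands the x_embedder input features; with this patch the
    # quantized weight is dequantized before the layer is expanded, so loading works
    # on bitsandbytes 4-bit and GGUF checkpoints as well.
    pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

    control_image = load_image(
        "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png"
    )
    image = pipe(
        prompt="A robot made of exotic candies and chocolates of different kinds.",
        control_image=control_image,
        height=256,   # small values as in the nightly test; increase for real outputs
        width=256,
        num_inference_steps=10,
        guidance_scale=30.0,
        generator=torch.manual_seed(0),
    ).images[0]
    image.save("flux_canny_gguf_lora.png")

The same flow applies to the bitsandbytes path in SlowBnb4BitFluxControlWithLoraTests, substituting the pre-quantized "eramth/flux-4bit" checkpoint loaded via FluxControlPipeline.from_pretrained.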