1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Flux quantized with lora (#10990)

* Flux quantized with lora

* fix

* changes

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply style fixes

* enable model cpu offload()

* Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: hlky <hlky@hlky.ac>

* update

* Apply suggestions from code review

* update

* add peft as an additional dependency for gguf

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
hlky
2025-04-08 16:47:03 +01:00
committed by GitHub
parent 71f34fc5a4
commit 5d49b3e83b
5 changed files with 151 additions and 7 deletions

View File

@@ -417,7 +417,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: []
additional_deps: ["peft"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []

View File

@@ -22,6 +22,8 @@ from ..utils import (
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
is_bitsandbytes_available,
is_gguf_available,
is_peft_available,
is_peft_version,
is_torch_version,
@@ -68,6 +70,49 @@ TRANSFORMER_NAME = "transformer"
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
def _maybe_dequantize_weight_for_expanded_lora(model, module):
if is_bitsandbytes_available():
from ..quantizers.bitsandbytes import dequantize_bnb_weight
if is_gguf_available():
from ..quantizers.gguf.utils import dequantize_gguf_tensor
is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
if is_bnb_4bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_gguf_quantized and not is_gguf_available():
raise ValueError(
"The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
)
weight_on_cpu = False
if not module.weight.is_cuda:
weight_on_cpu = True
if is_bnb_4bit_quantized:
module_weight = dequantize_bnb_weight(
module.weight.cuda() if weight_on_cpu else module.weight,
state=module.weight.quant_state,
dtype=model.dtype,
).data
elif is_gguf_quantized:
module_weight = dequantize_gguf_tensor(
module.weight.cuda() if weight_on_cpu else module.weight,
)
module_weight = module_weight.to(model.dtype)
else:
module_weight = module.weight.data
if weight_on_cpu:
module_weight = module_weight.cpu()
return module_weight
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -2267,6 +2312,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
overwritten_params = {}
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
is_quantized = hasattr(transformer, "hf_quantizer")
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
@@ -2291,9 +2337,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if tuple(module_weight_shape) == (out_features, in_features):
continue
# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
module_out_features, module_in_features = module_weight.shape
module_out_features, module_in_features = module_weight_shape
debug_message = ""
if in_features > module_in_features:
debug_message += (
@@ -2316,6 +2360,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
if is_quantized:
module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
# TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2327,7 +2375,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
slices = tuple(slice(0, dim) for dim in module_weight_shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
@@ -2416,7 +2464,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
if weight.__class__.__name__ == "Params4bit":
return weight.quant_state.shape
elif weight.__class__.__name__ == "GGUFParameter":
return weight.quant_shape
else:
return weight.shape
if base_module is not None:
return _get_weight_shape(base_module.weight)

View File

@@ -400,6 +400,8 @@ class GGUFParameter(torch.nn.Parameter):
data = data if data is not None else torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.quant_type = quant_type
block_size, type_size = GGML_QUANT_SIZES[quant_type]
self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)
return self

View File

@@ -21,8 +21,15 @@ import numpy as np
import pytest
import safetensors.torch
from huggingface_hub import hf_hub_download
from PIL import Image
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers import (
BitsAndBytesConfig,
DiffusionPipeline,
FluxControlPipeline,
FluxTransformer2DModel,
SD3Transformer2DModel,
)
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
@@ -696,6 +703,42 @@ class SlowBnb4BitFluxTests(Base4bitTests):
self.assertTrue(max_diff < 1e-3)
@require_transformers_version_greater("4.44.0")
@require_peft_backend
class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
self.pipeline_4bit.enable_model_cpu_offload()
def tearDown(self):
del self.pipeline_4bit
gc.collect()
torch.cuda.empty_cache()
def test_lora_loading(self):
self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
output = self.pipeline_4bit(
prompt=self.prompt,
control_image=Image.new(mode="RGB", size=(256, 256)),
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.Generator().manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
@slow
class BaseBnb4BitSerializationTests(Base4bitTests):
def tearDown(self):

View File

@@ -8,12 +8,14 @@ import torch.nn as nn
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FluxControlPipeline,
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
is_gguf_available,
nightly,
@@ -21,6 +23,7 @@ from diffusers.utils.testing_utils import (
require_accelerate,
require_big_gpu_with_torch_cuda,
require_gguf_version_greater_or_equal,
require_peft_backend,
torch_device,
)
@@ -456,3 +459,46 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
@require_peft_backend
@nightly
@require_big_gpu_with_torch_cuda
@require_accelerate
@require_gguf_version_greater_or_equal("0.10.0")
class FluxControlLoRAGGUFTests(unittest.TestCase):
def test_lora_loading(self):
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image(
"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png"
)
output = pipe(
prompt=prompt,
control_image=control_image,
height=256,
width=256,
num_inference_steps=10,
guidance_scale=30.0,
output_type="np",
generator=torch.manual_seed(0),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.8047, 0.8359, 0.8711, 0.6875, 0.7070, 0.7383, 0.5469, 0.5820, 0.6641])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)