
[LoRA] feat: support loading loras into 4bit quantized Flux models. (#10578)

* feat: support loading loras into 4bit quantized models.

* updates

* update

* remove weight check.
Author: Sayak Paul
Date: 2025-01-15 12:40:40 +05:30
Committed by: DN6
Parent: a663a67ea2
Commit: 263b973466

4 changed files with 71 additions and 4 deletions
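
For context, a minimal usage sketch of the workflow this commit enables: loading a LoRA into a bnb 4-bit quantized Flux transformer. The checkpoint and LoRA IDs below are illustrative assumptions, not part of the commit; the 4-bit test added in this change is the authoritative example.

    import torch
    from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

    # Quantize only the transformer to 4-bit with bitsandbytes.
    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")

    # Loading a LoRA into the 4-bit model is what this commit makes work: bnb flattens
    # 4-bit weights, so module shapes are now derived from `quant_state` instead.
    pipe.load_lora_weights(
        "ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors", adapter_name="hyper-sd"
    )
    pipe.set_adapters("hyper-sd", adapter_weights=0.125)
    image = pipe("a puppy in a field", num_inference_steps=8, guidance_scale=3.5).images[0]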

src/diffusers/loaders/lora_pipeline.py

@@ -21,6 +21,7 @@ from huggingface_hub.utils import validate_hf_hub_args
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -1981,10 +1982,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]
 
+                # Model maybe loaded with different quantization schemes which may flatten the params.
+                # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
+                # preserve weight shape.
+                module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
                 # This means there's no need for an expansion in the params, so we simply skip.
-                if tuple(module_weight.shape) == (out_features, in_features):
+                if tuple(module_weight_shape) == (out_features, in_features):
                     continue
 
+                # TODO (sayakpaul): We still need to consider if the module we're expanding is
+                # quantized and handle it accordingly if that is the case.
                 module_out_features, module_in_features = module_weight.shape
                 debug_message = ""
                 if in_features > module_in_features:
@@ -2080,13 +2088,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
 
-            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
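
The expansion above zero-pads `lora_A` along its in_features axis so it matches a base module that was widened (e.g. a transformer whose `x_embedder` takes extra control input channels). Zero columns make the LoRA a no-op on the added channels while leaving its behavior on the original channels unchanged; a small self-contained check, with shapes chosen arbitrarily for illustration:

    import torch

    rank, old_in, new_in = 4, 8, 12
    lora_A = torch.randn(rank, old_in)

    # Same expansion as in the hunk above: allocate the wider matrix, copy the old columns.
    expanded = torch.zeros(rank, new_in)
    expanded[:, :old_in].copy_(lora_A)

    x = torch.randn(2, new_in)  # input to the widened base Linear
    # Extra input channels hit zero weights, so the LoRA down-projection is unchanged.
    assert torch.allclose(x @ expanded.T, x[:, :old_in] @ lora_A.T)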
@@ -2098,6 +2109,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
 
         return lora_state_dict
 
+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
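
Why the `Params4bit` special case: bitsandbytes stores 4-bit weights as a packed, flattened buffer, so the parameter's `.shape` no longer equals the logical `(out_features, in_features)`; the original shape is preserved on the parameter's `quant_state`. A rough sketch of that behavior, assuming `bitsandbytes` is installed and a CUDA device is available (the packed shape is a bnb implementation detail and may differ):

    import torch
    import bitsandbytes as bnb

    # Quantization happens when the parameter is moved to the GPU.
    linear4bit = bnb.nn.Linear4bit(64, 32, bias=True).to("cuda")

    w = linear4bit.weight
    print(w.__class__.__name__)        # "Params4bit"
    print(tuple(w.shape))              # packed 4-bit storage, no longer (32, 64)
    print(tuple(w.quant_state.shape))  # logical shape the LoRA loader needs: (32, 64)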

src/diffusers/utils/__init__.py

@@ -100,7 +100,7 @@ from .import_utils import (
     is_xformers_available,
     requires_backends,
 )
-from .loading_utils import get_module_from_name, load_image, load_video
+from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (

src/diffusers/utils/loading_utils.py

@@ -148,3 +148,15 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
             module = new_module
         tensor_name = splits[-1]
     return module, tensor_name
+
+
+def get_submodule_by_name(root_module, module_path: str):
+    current = root_module
+    parts = module_path.split(".")
+    for part in parts:
+        if part.isdigit():
+            idx = int(part)
+            current = current[idx]  # e.g., for nn.ModuleList or nn.Sequential
+        else:
+            current = getattr(current, part)
+    return current
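
`get_submodule_by_name` walks a dotted module path, using integer indexing whenever a path component is purely numeric (as it is for `nn.ModuleList` / `nn.Sequential` entries) and `getattr` otherwise. A quick illustration with a toy module; the module names here are made up for the example:

    import torch.nn as nn
    from diffusers.utils import get_submodule_by_name  # exported by this change

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList([Block(), Block()])

    model = Toy()
    # "blocks.1.proj": getattr for "blocks", integer indexing for "1", getattr for "proj"
    sub = get_submodule_by_name(model, "blocks.1.proj")
    assert sub is model.blocks[1].proj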

tests/quantization/bnb/test_4bit.py

@@ -20,6 +20,7 @@ import unittest
 import numpy as np
 import pytest
 import safetensors.torch
+from huggingface_hub import hf_hub_download
 
 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
 from diffusers.utils import is_accelerate_version, logging
@@ -568,6 +569,27 @@ class SlowBnb4BitFluxTests(Base4bitTests):
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3)
 
+    def test_lora_loading(self):
+        self.pipeline_4bit.load_lora_weights(
+            hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+        )
+        self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125)
+
+        output = self.pipeline_4bit(
+            prompt=self.prompt,
+            height=256,
+            width=256,
+            max_sequence_length=64,
+            output_type="np",
+            num_inference_steps=8,
+            generator=torch.Generator().manual_seed(42),
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3)
+
 
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):