[LoRA] feat: support loading loras into 4bit quantized Flux models. (#10578)
* feat: support loading loras into 4bit quantized models.
* updates
* update
* remove weight check.
@@ -21,6 +21,7 @@ from huggingface_hub.utils import validate_hf_hub_args
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -1981,10 +1982,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             in_features = state_dict[lora_A_weight_name].shape[1]
             out_features = state_dict[lora_B_weight_name].shape[0]

+            # The model may be loaded with different quantization schemes, which may flatten the params.
+            # `bitsandbytes`, for example, flattens the weights when using 4bit, whereas 8bit bnb models
+            # preserve the weight shape.
+            module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
             # This means there's no need for an expansion in the params, so we simply skip.
-            if tuple(module_weight.shape) == (out_features, in_features):
+            if tuple(module_weight_shape) == (out_features, in_features):
                 continue

+            # TODO (sayakpaul): We still need to consider if the module we're expanding is
+            # quantized and handle it accordingly if that is the case.
             module_out_features, module_in_features = module_weight.shape
             debug_message = ""
             if in_features > module_in_features:
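The reason for routing through `_calculate_module_shape` above: a `bitsandbytes` 4bit layer stores its weight as a packed `Params4bit` tensor, so `weight.shape` no longer reflects the logical `(out_features, in_features)`, while `weight.quant_state.shape` does. A minimal sketch, assuming `bitsandbytes` and a CUDA device are available (the layer sizes and the packed shape shown are illustrative):

```python
import torch
import bitsandbytes as bnb

# A 4bit linear layer packs (flattens) its weight when it is moved to the GPU.
linear = bnb.nn.Linear4bit(64, 128, bias=False).cuda()

print(type(linear.weight).__name__)     # "Params4bit" -- the class name the helper checks for
print(linear.weight.shape)              # packed storage, e.g. torch.Size([4096, 1])
print(linear.weight.quant_state.shape)  # logical shape the helper returns: torch.Size([128, 64])
```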
@@ -2080,13 +2088,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

-            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
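For clarity, the `torch.zeros` expansion above only widens `lora_A` so that the extra input features of an already-expanded base layer contribute nothing to the LoRA update. A self-contained sketch with made-up dimensions:

```python
import torch

rank, lora_in, base_in = 16, 3072, 3072 + 768  # made-up dimensions for illustration

lora_A = torch.randn(rank, lora_in)
expanded = torch.zeros(rank, base_in)
expanded[:, :lora_in].copy_(lora_A)

x = torch.randn(2, base_in)
# Only the first `lora_in` input features influence the LoRA branch; the padded columns are inert.
assert torch.allclose(x @ expanded.T, x[:, :lora_in] @ lora_A.T)
```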
@@ -2098,6 +2109,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):

         return lora_state_dict

+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+

 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
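A hedged usage sketch of the new static helper on a plain, unquantized toy module, assuming `FluxLoraLoaderMixin` is importable from `diffusers.loaders`; both call styles should resolve to the same `torch.Size` (for a bnb 4bit weight the helper would fall back to `quant_state.shape` instead):

```python
import torch
from diffusers.loaders import FluxLoraLoaderMixin

class ToyBlock(torch.nn.Module):  # illustrative stand-in, not a diffusers module
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 128)

model = ToyBlock()
by_module = FluxLoraLoaderMixin._calculate_module_shape(model=model, base_module=model.proj)
by_name = FluxLoraLoaderMixin._calculate_module_shape(model=model, base_weight_param_name="proj.weight")
assert by_module == by_name == torch.Size([128, 64])
```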
@@ -101,7 +101,7 @@ from .import_utils import (
     is_xformers_available,
     requires_backends,
 )
-from .loading_utils import get_module_from_name, load_image, load_video
+from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
@@ -148,3 +148,15 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
             module = new_module
         tensor_name = splits[-1]
     return module, tensor_name
+
+
+def get_submodule_by_name(root_module, module_path: str):
+    current = root_module
+    parts = module_path.split(".")
+    for part in parts:
+        if part.isdigit():
+            idx = int(part)
+            current = current[idx]  # e.g., for nn.ModuleList or nn.Sequential
+        else:
+            current = getattr(current, part)
+    return current
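A quick usage sketch of `get_submodule_by_name`, relying on the `diffusers.utils` re-export added in the previous hunk. The `isdigit()` branch is what lets numeric path parts index into containers such as `nn.ModuleList`:

```python
import torch
from diffusers.utils import get_submodule_by_name

model = torch.nn.Module()
model.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(2)])

# "blocks.1" resolves to getattr(model, "blocks")[1]
layer = get_submodule_by_name(model, "blocks.1")
assert layer is model.blocks[1]
```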
@@ -20,6 +20,7 @@ import unittest
 import numpy as np
 import pytest
 import safetensors.torch
+from huggingface_hub import hf_hub_download

 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
 from diffusers.utils import is_accelerate_version, logging
@@ -568,6 +569,27 @@ class SlowBnb4BitFluxTests(Base4bitTests):
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3)

+    def test_lora_loading(self):
+        self.pipeline_4bit.load_lora_weights(
+            hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+        )
+        self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125)
+
+        output = self.pipeline_4bit(
+            prompt=self.prompt,
+            height=256,
+            width=256,
+            max_sequence_length=64,
+            output_type="np",
+            num_inference_steps=8,
+            generator=torch.Generator().manual_seed(42),
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3)
+

 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):
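Outside the test suite, the workflow this test exercises looks roughly like the sketch below: quantize the Flux transformer to NF4 with `bitsandbytes`, then load a LoRA on top of it. The Hyper-SD LoRA id comes from the test above; the base checkpoint id, prompt, and generation settings are illustrative assumptions (running it needs a CUDA GPU and access to the FLUX.1-dev weights):

```python
import torch
from huggingface_hub import hf_hub_download
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"  # assumed base checkpoint
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.bfloat16
)
pipe = DiffusionPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# The path exercised by the new test: LoRA weights loaded into a 4bit-quantized transformer.
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
pipe.set_adapters("hyper-sd", adapter_weights=0.125)

image = pipe("a photo of a puppy", num_inference_steps=8, guidance_scale=3.5).images[0]
```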