mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Flux quantized with lora (#10990)
* Flux quantized with lora
* fix
* changes
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Apply style fixes
* enable model cpu offload()
* Update src/diffusers/loaders/lora_pipeline.py
  Co-authored-by: hlky <hlky@hlky.ac>
* update
* Apply suggestions from code review
* update
* add peft as an additional dependency for gguf

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
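This change makes LoRA loading work on quantized Flux checkpoints (bitsandbytes 4-bit and GGUF), including Control LoRAs that require expanding the transformer's `x_embedder` input projection. A minimal usage sketch of the 4-bit path, assembled from the new tests added in this diff (the `eramth/flux-4bit` checkpoint and the Canny LoRA repo id are the ones those tests use; this is an illustrative sketch, not an official example):

import torch
from diffusers import FluxControlPipeline

# Pre-quantized 4-bit (bitsandbytes) Flux checkpoint, as used by the new
# SlowBnb4BitFluxControlWithLoraTests test class below.
pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# With this change, loading a Control LoRA on a quantized transformer works:
# layers that need a shape expansion are dequantized first.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")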
.github/workflows/nightly_tests.yml

@@ -417,7 +417,7 @@ jobs:
             additional_deps: ["peft"]
           - backend: "gguf"
             test_location: "gguf"
-            additional_deps: []
+            additional_deps: ["peft"]
           - backend: "torchao"
             test_location: "torchao"
             additional_deps: []
src/diffusers/loaders/lora_pipeline.py

@@ -22,6 +22,8 @@ from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
     get_submodule_by_name,
+    is_bitsandbytes_available,
+    is_gguf_available,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -68,6 +70,49 @@ TRANSFORMER_NAME = "transformer"
 _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
 
 
+def _maybe_dequantize_weight_for_expanded_lora(model, module):
+    if is_bitsandbytes_available():
+        from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+    if is_gguf_available():
+        from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+    is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+    if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
+    if is_gguf_quantized and not is_gguf_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+        )
+
+    weight_on_cpu = False
+    if not module.weight.is_cuda:
+        weight_on_cpu = True
+
+    if is_bnb_4bit_quantized:
+        module_weight = dequantize_bnb_weight(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+            state=module.weight.quant_state,
+            dtype=model.dtype,
+        ).data
+    elif is_gguf_quantized:
+        module_weight = dequantize_gguf_tensor(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+        )
+        module_weight = module_weight.to(model.dtype)
+    else:
+        module_weight = module.weight.data
+
+    if weight_on_cpu:
+        module_weight = module_weight.cpu()
+
+    return module_weight
+
+
 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
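A short sketch of how this helper is intended to be called (it mirrors the `FluxLoraLoaderMixin` change further below; `transformer` and `module` stand for the quantized Flux transformer and one of its `torch.nn.Linear` modules):

# Only dequantize when the transformer actually carries a quantizer; otherwise
# the plain weight data is used, exactly as before this change.
if hasattr(transformer, "hf_quantizer"):
    module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
else:
    module_weight = module.weight.data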
@@ -2267,6 +2312,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         overwritten_params = {}
 
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        is_quantized = hasattr(transformer, "hf_quantizer")
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
@@ -2291,9 +2337,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 if tuple(module_weight_shape) == (out_features, in_features):
                     continue
 
-                # TODO (sayakpaul): We still need to consider if the module we're expanding is
-                # quantized and handle it accordingly if that is the case.
-                module_out_features, module_in_features = module_weight.shape
+                module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
                     debug_message += (
@@ -2316,6 +2360,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 parent_module_name, _, current_module_name = name.rpartition(".")
                 parent_module = transformer.get_submodule(parent_module_name)
 
+                if is_quantized:
+                    module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
+
+                # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                 with torch.device("meta"):
                     expanded_module = torch.nn.Linear(
                         in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2327,7 +2375,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 new_weight = torch.zeros_like(
                     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                 )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                slices = tuple(slice(0, dim) for dim in module_weight_shape)
                 new_weight[slices] = module_weight
                 tmp_state_dict = {"weight": new_weight}
                 if module_bias is not None:
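The expansion copies the original (now dequantized) weight into the corner of a larger zero-initialized matrix, which is why the slices above are taken from the logical `module_weight_shape` rather than the raw quantized storage shape. A self-contained sketch of that idea (function name and standalone form are illustrative, not the actual method):

import torch

def expand_linear_weight(old_weight: torch.Tensor, out_features: int, in_features: int) -> torch.Tensor:
    # Zero-initialized weight with the expanded shape.
    new_weight = torch.zeros(out_features, in_features, dtype=old_weight.dtype, device=old_weight.device)
    # Copy the original weight into the top-left corner; the extra input columns stay zero.
    slices = tuple(slice(0, dim) for dim in old_weight.shape)
    new_weight[slices] = old_weight
    return new_weight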
@@ -2416,7 +2464,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         base_weight_param_name: str = None,
     ) -> "torch.Size":
         def _get_weight_shape(weight: torch.Tensor):
-            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+            if weight.__class__.__name__ == "Params4bit":
+                return weight.quant_state.shape
+            elif weight.__class__.__name__ == "GGUFParameter":
+                return weight.quant_shape
+            else:
+                return weight.shape
 
         if base_module is not None:
             return _get_weight_shape(base_module.weight)
src/diffusers/quantizers/gguf/utils.py

@@ -400,6 +400,8 @@ class GGUFParameter(torch.nn.Parameter):
         data = data if data is not None else torch.empty(0)
         self = torch.Tensor._make_subclass(cls, data, requires_grad)
         self.quant_type = quant_type
+        block_size, type_size = GGML_QUANT_SIZES[quant_type]
+        self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)
 
         return self
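`quant_shape` records the logical tensor shape, which differs from `.shape` because a GGUF parameter stores packed bytes: along the last dimension, every `type_size`-byte block encodes `block_size` elements. A rough sketch of what `_quant_shape_from_byte_shape` is assumed to compute (an assumption based on the name and GGUF packing conventions, not code taken from this diff):

# Hypothetical reimplementation for illustration only.
def quant_shape_from_byte_shape(byte_shape, type_size, block_size):
    *leading, last_bytes = byte_shape
    # Every `type_size` bytes in the last dimension encode `block_size` logical elements.
    return (*leading, last_bytes // type_size * block_size)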
In the bitsandbytes 4-bit quantization tests:

@@ -21,8 +21,15 @@ import numpy as np
 import pytest
 import safetensors.torch
 from huggingface_hub import hf_hub_download
+from PIL import Image
 
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxControlPipeline,
+    FluxTransformer2DModel,
+    SD3Transformer2DModel,
+)
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -696,6 +703,42 @@ class SlowBnb4BitFluxTests(Base4bitTests):
         self.assertTrue(max_diff < 1e-3)
 
 
+@require_transformers_version_greater("4.44.0")
+@require_peft_backend
+class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
+        self.pipeline_4bit.enable_model_cpu_offload()
+
+    def tearDown(self):
+        del self.pipeline_4bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_lora_loading(self):
+        self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+        output = self.pipeline_4bit(
+            prompt=self.prompt,
+            control_image=Image.new(mode="RGB", size=(256, 256)),
+            height=256,
+            width=256,
+            max_sequence_length=64,
+            output_type="np",
+            num_inference_steps=8,
+            generator=torch.Generator().manual_seed(42),
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
+
+
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):
     def tearDown(self):
In the GGUF quantization tests:

@@ -8,12 +8,14 @@ import torch.nn as nn
 from diffusers import (
     AuraFlowPipeline,
     AuraFlowTransformer2DModel,
+    FluxControlPipeline,
     FluxPipeline,
     FluxTransformer2DModel,
     GGUFQuantizationConfig,
     SD3Transformer2DModel,
     StableDiffusion3Pipeline,
 )
+from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     is_gguf_available,
     nightly,
@@ -21,6 +23,7 @@ from diffusers.utils.testing_utils import (
     require_accelerate,
     require_big_gpu_with_torch_cuda,
     require_gguf_version_greater_or_equal,
+    require_peft_backend,
     torch_device,
 )
@@ -456,3 +459,46 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
         )
         max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
         assert max_diff < 1e-4
+
+
+@require_peft_backend
+@nightly
+@require_big_gpu_with_torch_cuda
+@require_accelerate
+@require_gguf_version_greater_or_equal("0.10.0")
+class FluxControlLoRAGGUFTests(unittest.TestCase):
+    def test_lora_loading(self):
+        ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+        transformer = FluxTransformer2DModel.from_single_file(
+            ckpt_path,
+            quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+            torch_dtype=torch.bfloat16,
+        )
+        pipe = FluxControlPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            transformer=transformer,
+            torch_dtype=torch.bfloat16,
+        ).to("cuda")
+        pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+        prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+        control_image = load_image(
+            "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png"
+        )
+
+        output = pipe(
+            prompt=prompt,
+            control_image=control_image,
+            height=256,
+            width=256,
+            num_inference_steps=10,
+            guidance_scale=30.0,
+            output_type="np",
+            generator=torch.manual_seed(0),
+        ).images
+
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.8047, 0.8359, 0.8711, 0.6875, 0.7070, 0.7383, 0.5469, 0.5820, 0.6641])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3)
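For the GGUF path, the flow exercised by the new test condenses to: load a GGUF single-file transformer with `GGUFQuantizationConfig`, build the `FluxControlPipeline` around it, then load the Canny Control LoRA (repo ids are the ones the test above uses):

import torch
from diffusers import FluxControlPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# GGUF-quantized Flux transformer loaded from a single .gguf file.
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")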