Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Test error raised when loading normal and expanding loras together in Flux (#10188)
* add test for expanding lora and normal lora error
* Update tests/lora/test_lora_layers_flux.py
* fix things.
* Update src/diffusers/loaders/peft.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
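As a user-level illustration of what this commit guards against, the sketch below loads a shape-expanding LoRA first (the kind that grows `x_embedder`'s input features, as Flux Control LoRAs do) and then a LoRA trained against the original shapes. The base model id is the public FLUX.1-dev checkpoint, but both LoRA repo ids are placeholders; this is a hypothetical repro, not code from this PR:

import torch
from diffusers import FluxPipeline

# Hypothetical repro; the two LoRA repo ids below are placeholders, not real checkpoints.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# A shape-expanding LoRA grows x_embedder's input features when it is loaded.
pipe.load_lora_weights("your-org/flux-control-style-lora", adapter_name="control")

# Loading a LoRA with the original (narrower) shapes afterwards now fails loudly
# instead of leaving the pipeline in a broken state.
try:
    pipe.load_lora_weights("your-org/flux-regular-style-lora", adapter_name="style")
except RuntimeError as e:
    print(e)  # e.g. "size mismatch for x_embedder.lora_A.style.weight ..."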
src/diffusers/loaders/lora_pipeline.py
@@ -2337,12 +2337,19 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                         f"this please open an issue at https://github.com/huggingface/diffusers/issues."
                     )
 
-                logger.debug(
+                debug_message = (
                     f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
                     f"checkpoint contains higher number of features than expected. The number of input_features will be "
-                    f"expanded from {module_in_features} to {in_features}, and the number of output features will be "
-                    f"expanded from {module_out_features} to {out_features}."
+                    f"expanded from {module_in_features} to {in_features}"
                 )
+                if module_out_features != out_features:
+                    debug_message += (
+                        ", and the number of output features will be "
+                        f"expanded from {module_out_features} to {out_features}."
+                    )
+                else:
+                    debug_message += "."
+                logger.debug(debug_message)
 
                 has_param_with_shape_update = True
                 parent_module_name, _, current_module_name = name.rpartition(".")
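For context, the debug message built above refers to the loader growing a transformer `nn.Linear` in place when the incoming LoRA (for example a Flux Control LoRA) carries more input features than the base module. The snippet below is a minimal sketch of that idea under the assumption that only the input dimension grows and the new columns start at zero; it is not the diffusers implementation, and `expand_linear_in_features` is a hypothetical helper:

import torch

def expand_linear_in_features(module: torch.nn.Linear, in_features: int) -> torch.nn.Linear:
    # Assumes in_features >= module.in_features; the extra input columns are zero-initialized.
    expanded = torch.nn.Linear(in_features, module.out_features, bias=module.bias is not None)
    with torch.no_grad():
        expanded.weight.zero_()
        expanded.weight[:, : module.in_features].copy_(module.weight)
        if module.bias is not None:
            expanded.bias.copy_(module.bias)
    return expanded

old = torch.nn.Linear(64, 3072, bias=True)
new = expand_linear_in_features(old, 128)
print(new.weight.shape)  # torch.Size([3072, 128]); the first 64 columns hold the old weights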
src/diffusers/loaders/peft.py
@@ -205,6 +205,7 @@ class PeftAdapterMixin:
             weights.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft.tuners.tuners_utils import BaseTunerLayer
 
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
@@ -316,8 +317,22 @@ class PeftAdapterMixin:
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
-            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+            # To handle scenarios where we cannot successfully set the state dict. If it's unsuccessful,
+            # we should also delete the `peft_config` associated to the `adapter_name`.
+            try:
+                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+            except RuntimeError as e:
+                for module in self.modules():
+                    if isinstance(module, BaseTunerLayer):
+                        active_adapters = module.active_adapters
+                        for active_adapter in active_adapters:
+                            if adapter_name in active_adapter:
+                                module.delete_adapter(adapter_name)
+
+                self.peft_config.pop(adapter_name)
+                logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
+                raise
 
             warn_msg = ""
             if incompatible_keys is not None:
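The hunk above is essentially a rollback guard: if injecting the adapter or setting its state dict raises (for example on a shape mismatch), the partially registered adapter and its `peft_config` entry are removed before the error is re-raised, so the model keeps whatever adapters were already loaded. The snippet below sketches that pattern with a hypothetical `AdapterHost` registry; it does not use the PEFT API:

class AdapterHost:
    # Toy stand-in for a model that tracks named adapters.
    def __init__(self):
        self.adapters = {}

    def _apply_weights(self, name, weights):
        # Stand-in for set_peft_model_state_dict failing on incompatible shapes.
        if "bad" in weights:
            raise RuntimeError(f"size mismatch while loading adapter {name!r}")
        self.adapters[name]["weights"] = weights

    def load_adapter(self, name, config, weights):
        self.adapters[name] = {"config": config}  # registration happens first
        try:
            self._apply_weights(name, weights)
        except RuntimeError:
            self.adapters.pop(name, None)  # roll back the partial registration
            raise

host = AdapterHost()
try:
    host.load_adapter("adapter-2", config={}, weights={"bad": 1})
except RuntimeError:
    pass
assert "adapter-2" not in host.adapters  # the host is unchanged after the failed load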
tests/lora/test_lora_layers_flux.py
@@ -430,6 +430,122 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
 
+    def test_lora_expanding_shape_with_normal_lora_raises_error(self):
+        # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
+        # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
+        # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        components["transformer"] = transformer
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+        self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
+        normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+        }
+
+        # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
+        # input features before expansion. This should raise an error about the weight shapes being incompatible.
+        self.assertRaisesRegex(
+            RuntimeError,
+            "size mismatch for x_embedder.lora_A.adapter-2.weight",
+            pipe.load_lora_weights,
+            lora_state_dict,
+            "adapter-2",
+        )
+        # We should have `adapter-1` as the only adapter.
+        self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
+
+        # Check if the output is the same after lora loading error
+        lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
+
+        # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
+        # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
+        # original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
+        # weight is compatible with the current model inadequate. This should be addressed when attempting support for
+        # https://github.com/huggingface/diffusers/issues/10180 (TODO)
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        components["transformer"] = transformer
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+        }
+
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features)
+        self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+        }
+
+        # We should check for input shapes being incompatible here. But because above mentioned issue is
+        # not a supported use case, and because of the PEFT renaming, we will currently have a shape
+        # mismatch error.
+        self.assertRaisesRegex(
+            RuntimeError,
+            "size mismatch for x_embedder.lora_A.adapter-2.weight",
+            pipe.load_lora_weights,
+            lora_state_dict,
+            "adapter-2",
+        )
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
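One small reading note on the test above: `assertRaisesRegex` is used in its functional form, where the callable and its arguments are passed in rather than wrapped in a `with` block. A self-contained sketch of that form, with a hypothetical `load` function standing in for `pipe.load_lora_weights`:

import unittest

def load(state_dict, adapter_name):
    # Stand-in for pipe.load_lora_weights failing on incompatible shapes.
    raise RuntimeError(f"size mismatch for x_embedder.lora_A.{adapter_name}.weight")

class Demo(unittest.TestCase):
    def test_functional_form(self):
        # Equivalent to: with self.assertRaisesRegex(RuntimeError, "size mismatch"): load({}, "adapter-2")
        self.assertRaisesRegex(RuntimeError, "size mismatch", load, {}, "adapter-2")

if __name__ == "__main__":
    unittest.main()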