
[LoRA] fix: lora unloading when using expanded Flux LoRAs. (#10397)

* fix: lora unloading when using expanded Flux LoRAs.

* fix argument name.

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>

* docs.

---------

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>
Sayak Paul
2025-01-07 00:05:05 +05:30
committed by GitHub
parent 2f25156c14
commit d9d94e12f3
3 changed files with 83 additions and 4 deletions

docs/source/en/api/pipelines/flux.md (View File)

@@ -305,6 +305,10 @@ image = control_pipe(
image.save("output.png")
```

## Note about `unload_lora_weights()` when using Flux LoRAs

When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
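For illustration, here is a minimal sketch of that workflow. The checkpoint ids, the Canny Control LoRA, and the prompt-free flow are assumptions for the example, not requirements:

```python
import torch
from diffusers import FluxControlPipeline, FluxPipeline

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")  # expands `x_embedder`

# ... run control-conditioned inference here ...

# Shrink the transformer back to its pre-expansion parameters before reuse.
pipe.unload_lora_weights(reset_to_overwritten_params=True)
text_to_image = FluxPipeline.from_pipe(pipe)
```

Without the reset, `from_pipe` would inherit a transformer whose `x_embedder` still expects the doubled input channels.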
## Running FP16 inference

Flux can generate high-quality images with FP16 (for example, to accelerate inference on Turing/Volta GPUs), but it produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing the text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
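A minimal sketch of one way to realize this, assuming the FLUX.1-schnell checkpoint and illustrative generation settings; the idea is to encode prompts with FP32 text encoders and hand FP16 embeddings to the transformer:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
pipe.text_encoder.to(torch.float32)    # CLIP text encoder in FP32
pipe.text_encoder_2.to(torch.float32)  # T5 text encoder in FP32
pipe.to("cuda")

# Encode with the FP32 encoders, then cast the embeddings down to FP16.
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt="A cat holding a sign that says hello world", prompt_2=None
    )

image = pipe(
    prompt_embeds=prompt_embeds.to(torch.float16),
    pooled_prompt_embeds=pooled_prompt_embeds.to(torch.float16),
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("fp16_output.png")
```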

src/diffusers/loaders/lora_pipeline.py (View File)

@@ -2277,8 +2277,24 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components)
 
-    # We override this here to account for `_transformer_norm_layers`.
-    def unload_lora_weights(self):
+    # We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
+    def unload_lora_weights(self, reset_to_overwritten_params=False):
+        """
+        Unloads the LoRA parameters.
+
+        Args:
+            reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
+                to their original params. Refer to the [Flux
+                documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+        >>> pipeline.unload_lora_weights()
+        >>> ...
+        ```
+        """
         super().unload_lora_weights()
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
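To make the new flag concrete: with the default `reset_to_overwritten_params=False`, unloading keeps any shape-expanded modules as they are; passing `True` rebuilds them from the cached originals. A hypothetical sketch, assuming `pipe` is a Flux Control pipeline whose LoRA expands `x_embedder`:

```python
import torch

# `pipe` is assumed to be a Flux Control pipeline; the dummy LoRA below
# doubles the input features of `x_embedder`, mirroring the tests in this PR.
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
lora_state_dict = {
    "transformer.x_embedder.lora_A.weight": torch.randn(rank, 2 * in_features),
    "transformer.x_embedder.lora_B.weight": torch.randn(out_features, rank),
}
pipe.load_lora_weights(lora_state_dict, adapter_name="expansion")
assert pipe.transformer.x_embedder.in_features == 2 * in_features  # expanded

pipe.unload_lora_weights()  # default: reset_to_overwritten_params=False
assert pipe.transformer.x_embedder.in_features == 2 * in_features  # expansion kept

# pipe.unload_lora_weights(reset_to_overwritten_params=True) would instead
# rebuild `x_embedder` with the original in_features and reset the config.
```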
@@ -2286,7 +2302,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None
 
-        if getattr(transformer, "_overwritten_params", None) is not None:
+        if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
             overwritten_params = transformer._overwritten_params
             module_names = set()
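The hunk cuts off before the restoration loop that consumes `module_names`. As a rough, simplified sketch of the idea (not the verbatim implementation): each expanded `nn.Linear` recorded in `_overwritten_params` is rebuilt with its pre-expansion shape and swapped back into its parent module.

```python
# Simplified sketch, continuing with `overwritten_params`, `module_names`,
# and `transformer` from the block above.
for param_name in overwritten_params:
    if param_name.endswith(".weight"):
        module_names.add(param_name.removesuffix(".weight"))

for name in module_names:
    weight = overwritten_params[f"{name}.weight"]
    bias = overwritten_params.get(f"{name}.bias")
    out_features, in_features = weight.shape

    # Recreate the module with its original (pre-expansion) shape.
    original = torch.nn.Linear(
        in_features, out_features, bias=bias is not None, dtype=weight.dtype, device=weight.device
    )
    state = {"weight": weight}
    if bias is not None:
        state["bias"] = bias
    original.load_state_dict(state, strict=True)

    # Swap the rebuilt module back into its parent.
    parent_name, _, child_name = name.rpartition(".")
    setattr(transformer.get_submodule(parent_name), child_name, original)

# Config entries that tracked the expanded shapes (e.g. `in_channels`) are
# then reset as well, e.g. via transformer.register_to_config(in_channels=...).
```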

tests/lora/test_lora_layers_flux.py (View File)

@@ -706,7 +706,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-        control_pipe.unload_lora_weights()
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
         self.assertTrue(
             control_pipe.transformer.config.in_channels == num_channels_without_control,
             f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
@@ -724,6 +724,65 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
         self.assertTrue(pipe.transformer.config.in_channels == in_features)
 
+    def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+        components["transformer"] = transformer
+        pipe = FluxPipeline(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        control_image = inputs.pop("control_image")
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        control_pipe = self.pipeline_class(**components)
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        inputs["control_image"] = control_image
+        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
+        self.assertTrue(
+            control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
+            f"Expected {2 * num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
+        )
+        no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass