From d61bb38fb4b21acc62a6b06a0367ceb30434ff45 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 3 Oct 2025 13:14:05 +0530
Subject: [PATCH] up

---
 tests/lora/test_lora_layers_flux.py | 204 ++++++++++++++++++++++------
 1 file changed, 162 insertions(+), 42 deletions(-)

diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index ff53983ecf..f75a7b3777 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -1,3 +1,17 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import copy
 import gc
 import os
@@ -68,10 +82,10 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
         "scaling_factor": 1.5035,
     }
     has_two_text_encoders = True
-    (tokenizer_cls, tokenizer_id) = (CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2")
-    (tokenizer_2_cls, tokenizer_2_id) = (AutoTokenizer, "hf-internal-testing/tiny-random-t5")
-    (text_encoder_cls, text_encoder_id) = (CLIPTextModel, "peft-internal-testing/tiny-clip-text-2")
-    (text_encoder_2_cls, text_encoder_2_id) = (T5EncoderModel, "hf-internal-testing/tiny-random-t5")
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
 
     @property
     def output_shape(self):
@@ -82,9 +96,11 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
         sequence_length = 10
         num_channels = 4
         sizes = (32, 32)
+
         generator = torch.manual_seed(0)
         noise = floats_tensor((batch_size, num_channels) + sizes)
         input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
             "num_inference_steps": 4,
@@ -95,23 +111,31 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
         }
         if with_generator:
             pipeline_inputs.update({"generator": generator})
-        return (noise, input_ids, pipeline_inputs)
+
+        return noise, input_ids, pipeline_inputs
 
     def test_with_alpha_in_state_dict(self):
-        (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
         pipe.transformer.add_adapter(denoiser_lora_config)
-        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
         images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
             self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
             assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            # modify the state dict to have alpha values following
+            # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
             state_dict_with_alpha = safetensors.torch.load_file(
                 os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
             )
@@ -120,8 +144,10 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
                 if "transformer" in k and "to_k" in k and ("lora_A" in k):
                     alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
             state_dict_with_alpha.update(alpha_dict)
+
         images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
+
         pipe.unload_lora_weights()
         pipe.load_lora_weights(state_dict_with_alpha)
         images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
@@ -131,15 +157,19 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
         assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001)
 
     def test_lora_expansion_works_for_absent_keys(self, base_pipe_output):
-        (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        # Modify the config to have a layer which won't be present in the second LoRA we will load.
         modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
         modified_denoiser_lora_config.target_modules.add("x_embedder")
+
         pipe.transformer.add_adapter(modified_denoiser_lora_config)
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
+
         images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         assert not (
             np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001),
@@ -148,14 +178,18 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
         with tempfile.TemporaryDirectory() as tmpdirname:
             denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
             self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
             assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
             lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
             lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
+
             pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
             pipe.set_adapters(["one", "two"])
             assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
+
             images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
             assert not (
                 np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001),
@@ -167,15 +201,17 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
             )
 
     def test_lora_expansion_works_for_extra_keys(self, base_pipe_output):
-        (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
         modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
         modified_denoiser_lora_config.target_modules.add("x_embedder")
         pipe.transformer.add_adapter(modified_denoiser_lora_config)
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
+
         images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         assert not (
             np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001),
@@ -185,13 +221,16 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
             denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
             self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
             assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
             pipe.unload_lora_weights()
             lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
             lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
             pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
+
             pipe.set_adapters(["one", "two"])
             assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
+
             images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
             assert not (
                 np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001),
@@ -250,10 +289,10 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         "scaling_factor": 1.5035,
     }
     has_two_text_encoders = True
-    (tokenizer_cls, tokenizer_id) = (CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2")
-    (tokenizer_2_cls, tokenizer_2_id) = (AutoTokenizer, "hf-internal-testing/tiny-random-t5")
-    (text_encoder_cls, text_encoder_id) = (CLIPTextModel, "peft-internal-testing/tiny-clip-text-2")
-    (text_encoder_2_cls, text_encoder_2_id) = (T5EncoderModel, "hf-internal-testing/tiny-random-t5")
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
 
     @property
     def output_shape(self):
@@ -264,9 +303,11 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         sequence_length = 10
         num_channels = 4
         sizes = (32, 32)
+
         generator = torch.manual_seed(0)
         noise = floats_tensor((batch_size, num_channels) + sizes)
         input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
         np.random.seed(0)
         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
@@ -279,17 +320,22 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         }
         if with_generator:
             pipeline_inputs.update({"generator": generator})
-        return (noise, input_ids, pipeline_inputs)
+
+        return noise, input_ids, pipeline_inputs
 
     def test_with_norm_in_state_dict(self):
-        (components, _, denoiser_lora_config) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.INFO)
+
         original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
         for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]:
             norm_state_dict = {}
             for name, module in pipe.transformer.named_modules():
@@ -298,14 +344,17 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
                 norm_state_dict[f"transformer.{name}.weight"] = torch.randn(
                     module.weight.shape, device=module.weight.device, dtype=module.weight.dtype
                 )
+
             with CaptureLogger(logger) as cap_logger:
                 pipe.load_lora_weights(norm_state_dict)
                 lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
             assert (
                 "The provided state dict contains normalization layers in addition to LoRA layers"
                 in cap_logger.out
             )
             assert len(pipe.transformer._transformer_norm_layers) > 0
+
             pipe.unload_lora_weights()
             lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
             assert pipe.transformer._transformer_norm_layers is None
@@ -314,6 +363,7 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
                 np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06),
                 f"{norm_layer} is tested",
             )
+
             with CaptureLogger(logger) as cap_logger:
                 for key in list(norm_state_dict.keys()):
                     norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key)
@@ -321,14 +371,17 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
             assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
 
     def test_lora_parameter_expanded_shapes(self):
-        (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
         original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
         num_channels_without_control = 4
         transformer = FluxTransformer2DModel.from_config(
             components["transformer"].config, in_channels=num_channels_without_control
@@ -336,15 +389,17 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         assert transformer.config.in_channels == num_channels_without_control, (
             f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
         )
+
         original_transformer_state_dict = pipe.transformer.state_dict()
         x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
         incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
         assert "x_embedder.weight" in incompatible_keys.missing_keys, (
             "Could not find x_embedder.weight in the missing keys."
         )
+
         transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
         pipe.transformer = transformer
-        (out_features, in_features) = pipe.transformer.x_embedder.weight.shape
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4
         dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
         dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
@@ -355,12 +410,15 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
+
         lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
         assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
         assert pipe.transformer.config.in_channels == 2 * in_features
         assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
-        (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        # Testing opposite direction where the LoRA params are zero-padded.
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -373,6 +431,7 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
+
         lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
         assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
@@ -380,19 +439,27 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
 
     def test_normal_lora_with_expanded_lora_raises_error(self):
-        (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
+        # load shape expanded LoRA (such as Control LoRA).
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        # Change the transformer config to mimic a real use case.
         num_channels_without_control = 4
         transformer = FluxTransformer2DModel.from_config(
             components["transformer"].config, in_channels=num_channels_without_control
         ).to(torch_device)
         components["transformer"] = transformer
+
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
+
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)
-        (out_features, in_features) = pipe.transformer.x_embedder.weight.shape
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4
+
         shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
         shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
         lora_state_dict = {
@@ -416,23 +483,32 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-2")
+
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
         assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
         assert pipe.get_active_adapters() == ["adapter-2"]
+
         lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
-        (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
+        # This should raise a runtime error on input shapes being incompatible.
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        # Change the transformer config to mimic a real use case.
         num_channels_without_control = 4
         transformer = FluxTransformer2DModel.from_config(
             components["transformer"].config, in_channels=num_channels_without_control
         ).to(torch_device)
         components["transformer"] = transformer
+
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
+
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)
-        (out_features, in_features) = pipe.transformer.x_embedder.weight.shape
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4
         lora_state_dict = {
             "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
@@ -442,27 +518,40 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
         assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
         assert pipe.transformer.config.in_channels == in_features
+
         lora_state_dict = {
             "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
             "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
         }
+        # We should check for input shapes being incompatible here. But because above mentioned issue is
+        # not a supported use case, and because of the PEFT renaming, we will currently have a shape
+        # mismatch error.
         with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"):
             pipe.load_lora_weights(lora_state_dict, "adapter-2")
 
     def test_fuse_expanded_lora_with_regular_lora(self):
-        (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        # This test checks if it works when a lora with expanded shapes (like control loras) but
+        # another lora with correct shapes is loaded. The opposite direction isn't supported and is
+        # tested with it.
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        # Change the transformer config to mimic a real use case.
         num_channels_without_control = 4
         transformer = FluxTransformer2DModel.from_config(
             components["transformer"].config, in_channels=num_channels_without_control
         ).to(torch_device)
         components["transformer"] = transformer
+
        pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
+
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)
-        (out_features, in_features) = pipe.transformer.x_embedder.weight.shape
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4
+
         shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
         shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
         lora_state_dict = {
@@ -471,34 +560,42 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         }
         pipe.load_lora_weights(lora_state_dict, "adapter-1")
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
-        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
         lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
         normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
         normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
         lora_state_dict = {
             "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
             "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
         }
+
         pipe.load_lora_weights(lora_state_dict, "adapter-2")
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
+
         lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
         pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
         lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
         assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001)
         assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001)
+
         pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
         lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001)
 
-    def test_load_regular_lora(self):
-        (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+    def test_load_regular_lora(self, base_pipe_output):
+        # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
+        # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
+        # transformers include Flux Fill, Flux Control, etc.
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
-        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        (out_features, in_features) = pipe.transformer.x_embedder.weight.shape
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4
         in_features = in_features // 2
         normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
@@ -512,15 +609,19 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
+
         lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
         assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
-        assert not np.allclose(original_output, lora_output, atol=0.001, rtol=0.001)
+        assert not np.allclose(base_pipe_output, lora_output, atol=0.001, rtol=0.001)
 
     def test_lora_unload_with_parameter_expanded_shapes(self):
-        (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
         num_channels_without_control = 4
         transformer = FluxTransformer2DModel.from_config(
             components["transformer"].config, in_channels=num_channels_without_control
@@ -528,16 +629,21 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         assert transformer.config.in_channels == num_channels_without_control, (
             f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
         )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
         components["transformer"] = transformer
         pipe = FluxPipeline(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
         control_image = inputs.pop("control_image")
         original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
         control_pipe = self.pipeline_class(**components)
-        (out_features, in_features) = control_pipe.transformer.x_embedder.weight.shape
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
         rank = 4
+
         dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
         dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
         lora_state_dict = {
@@ -547,20 +653,24 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         with CaptureLogger(logger) as cap_logger:
             control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
+
         inputs["control_image"] = control_image
         lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
         assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
         assert pipe.transformer.config.in_channels == 2 * in_features
         assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
+
         control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
         assert control_pipe.transformer.config.in_channels == num_channels_without_control, (
             f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
        )
+
         loaded_pipe = FluxPipeline.from_pipe(control_pipe)
         assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, (
             f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}"
         )
+
         inputs.pop("control_image")
         unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001)
@@ -569,9 +679,11 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         assert pipe.transformer.config.in_channels == in_features
 
     def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
-        (components, _, _) = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)
+
         num_channels_without_control = 4
         transformer = FluxTransformer2DModel.from_config(
             components["transformer"].config, in_channels=num_channels_without_control
@@ -579,16 +691,21 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         assert transformer.config.in_channels == num_channels_without_control, (
             f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
         )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
         components["transformer"] = transformer
         pipe = FluxPipeline(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
         control_image = inputs.pop("control_image")
         original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
         control_pipe = self.pipeline_class(**components)
-        (out_features, in_features) = control_pipe.transformer.x_embedder.weight.shape
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
         rank = 4
+
         dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
         dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
         lora_state_dict = {
@@ -598,16 +715,19 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
         with CaptureLogger(logger) as cap_logger:
             control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
         assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
+
         inputs["control_image"] = control_image
         lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
         assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
         assert pipe.transformer.config.in_channels == 2 * in_features
         assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
+
         control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
         assert control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, (
             f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
         )
+
         no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001)
         assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2