[LoRA] minor fix for load_lora_weights() for Flux and a test (#11595)
* fix peft delete adapters for flux.

* add test

* empty commit
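In user terms, the round trip this commit makes reliable is: load LoRA weights into a Flux pipeline, delete the adapter, then load again. A minimal sketch of that flow, assuming a Flux checkpoint and a placeholder LoRA repo id (both illustrative, not taken from the commit):

import torch
from diffusers import FluxPipeline

# Placeholder model and LoRA ids; any Flux-compatible pair would do.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

pipe.load_lora_weights("some-user/some-flux-lora")   # hypothetical LoRA repo
pipe.delete_adapters(pipe.get_active_adapters()[0])  # drop the active adapter

# Before this fix, the second load could resolve the wrong base parameter name
# inside FluxLoraLoaderMixin (".base_layer.weight" vs ".weight").
pipe.load_lora_weights("some-user/some-flux-lora")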
@@ -2545,14 +2545,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         if unexpected_modules:
             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
 
-        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for k in lora_module_names:
             if k in unexpected_modules:
                 continue
 
             base_param_name = (
                 f"{k.replace(prefix, '')}.base_layer.weight"
-                if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+                if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
                 else f"{k.replace(prefix, '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]
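Taken on its own, the selection rule after this change can be exercised against a toy state dict. A standalone sketch (the prefix and keys are illustrative, not from the commit) of how the presence of the PEFT-wrapped key now decides the parameter name by itself, with no separate peft_config gate:

prefix = "transformer."

def pick_base_param_name(k: str, transformer_state_dict: dict) -> str:
    # Presence of the wrapped key in the state dict is the single source of truth.
    wrapped = f"{k.replace(prefix, '')}.base_layer.weight"
    return wrapped if wrapped in transformer_state_dict else f"{k.replace(prefix, '')}.weight"

# Unwrapped module (e.g. after delete_adapters()): falls back to ".weight".
assert pick_base_param_name("transformer.x_embedder", {"x_embedder.weight": 0}) == "x_embedder.weight"

# PEFT-wrapped module: picks ".base_layer.weight".
assert pick_base_param_name("transformer.x_embedder", {"x_embedder.base_layer.weight": 0}) == "x_embedder.base_layer.weight"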
@@ -2149,3 +2149,51 @@ class PeftLoraLoaderMixinTests:
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+    def test_inference_load_delete_load_adapters(self):
+        "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+                self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+                # First, delete adapter and compare.
+                pipe.delete_adapters(pipe.get_active_adapters()[0])
+                output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
+                self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
+
+                # Then load adapter and compare.
+                pipe.load_lora_weights(tmpdirname)
+                output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
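Because PeftLoraLoaderMixinTests is a shared mixin, the new test is inherited by every pipeline-specific LoRA test class. Assuming the repository's usual test layout, the Flux variant can be run on its own with something like:

pytest tests/lora/test_lora_layers_flux.py -k test_inference_load_delete_load_adapters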