From ed2f956072a3b446d984f359ba6c427c259ab4ee Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 9 Oct 2023 18:01:55 +0200
Subject: [PATCH] Fix loading broken LoRAs that could give NaN (#5316)

* Fix fuse Lora

* improve a bit

* make style

* Update src/diffusers/models/lora.py

Co-authored-by: Benjamin Bossan

* ciao C file

* ciao C file

* test & make style

---------

Co-authored-by: Benjamin Bossan
---
 src/diffusers/loaders.py                   | 48 +++++++++++++++-------
 src/diffusers/models/lora.py               | 20 ++++++++-
 tests/lora/test_lora_layers_old_backend.py | 41 ++++++++++++++++++
 3 files changed, 92 insertions(+), 17 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 11858ac3bb..2cc547be01 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -121,7 +121,7 @@ class PatchedLoraProjection(nn.Module):
 
         return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_linear_layer is None:
             return
 
@@ -135,6 +135,14 @@ class PatchedLoraProjection(nn.Module):
             w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
 
         fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
+                "LoRA weights will not be fused."
+            )
+
         self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -672,13 +680,14 @@ class UNet2DConditionLoadersMixin:
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self, lora_scale=1.0):
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         self.lora_scale = lora_scale
+        self._safe_fusing = safe_fusing
         self.apply(self._fuse_lora_apply)
 
     def _fuse_lora_apply(self, module):
         if hasattr(module, "_fuse_lora"):
-            module._fuse_lora(self.lora_scale)
+            module._fuse_lora(self.lora_scale, self._safe_fusing)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
@@ -2086,7 +2095,13 @@ class LoraLoaderMixin:
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
-    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
+    def fuse_lora(
+        self,
+        fuse_unet: bool = True,
+        fuse_text_encoder: bool = True,
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+    ):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
 
@@ -2103,6 +2118,8 @@ class LoraLoaderMixin:
                 LoRA parameters then it won't have any effect.
             lora_scale (`float`, defaults to 1.0):
                 Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and raise an error if NaN values are found.
         """
         if fuse_unet or fuse_text_encoder:
             self.num_fused_loras += 1
@@ -2112,12 +2129,13 @@ class LoraLoaderMixin:
             )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale)
+            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if self.use_peft_backend:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+                # TODO(Patrick, Younes): enable "safe" fusing
                 for module in text_encoder.modules():
                     if isinstance(module, BaseTunerLayer):
                         if lora_scale != 1.0:
@@ -2129,24 +2147,24 @@ class LoraLoaderMixin:
             if version.parse(__version__) > version.parse("0.23"):
                 deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
                 for _, attn_module in text_encoder_attn_modules(text_encoder):
                     if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                        attn_module.q_proj._fuse_lora(lora_scale)
-                        attn_module.k_proj._fuse_lora(lora_scale)
-                        attn_module.v_proj._fuse_lora(lora_scale)
-                        attn_module.out_proj._fuse_lora(lora_scale)
+                        attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
 
                 for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                     if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                        mlp_module.fc1._fuse_lora(lora_scale)
-                        mlp_module.fc2._fuse_lora(lora_scale)
+                        mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
+                        mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
 
     def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
         r"""
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index a777bb93e1..aec7200afd 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -112,7 +112,7 @@ class LoRACompatibleConv(nn.Conv2d):
     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_layer is None:
             return
 
@@ -128,6 +128,14 @@ class LoRACompatibleConv(nn.Conv2d):
         fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
         fusion = fusion.reshape((w_orig.shape))
         fused_weight = w_orig + (lora_scale * fusion)
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
+                "LoRA weights will not be fused."
+            )
+
         self.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -182,7 +190,7 @@ class LoRACompatibleLinear(nn.Linear):
     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_layer is None:
             return
 
@@ -196,6 +204,14 @@ class LoRACompatibleLinear(nn.Linear):
             w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
 
         fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
+                "LoRA weights will not be fused."
+            )
+
         self.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py
index d616ef8c78..8c1fb48776 100644
--- a/tests/lora/test_lora_layers_old_backend.py
+++ b/tests/lora/test_lora_layers_old_backend.py
@@ -1028,6 +1028,47 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
 
         sd_pipe.unload_lora_weights()
 
+    def test_lora_fuse_nan(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        # corrupt one LoRA weight with `inf` values
+        with torch.no_grad():
+            sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
+                "inf"
+            )
+
+        # with `safe_fusing=True` we should see an error
+        with self.assertRaises(ValueError):
+            sd_pipe.fuse_lora(safe_fusing=True)
+
+        # without `safe_fusing` we should not see an error, but every image will be black
+        sd_pipe.fuse_lora(safe_fusing=False)
+
+        out = sd_pipe("test", num_inference_steps=2, output_type="np").images
+
+        assert np.isnan(out).all()
+
     def test_lora_fusion(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
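A minimal usage sketch of the new `safe_fusing` flag (not part of the patch itself): it assumes a hypothetical LoRA checkpoint path and an example SDXL base model, and only exercises the `load_lora_weights`, `fuse_lora`, and `unload_lora_weights` calls touched by this change.

    import torch
    from diffusers import StableDiffusionXLPipeline

    # Example base model; any pipeline mixing in LoraLoaderMixin works the same way.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    pipe.load_lora_weights("path/to/lora.safetensors")  # hypothetical LoRA checkpoint

    try:
        # With safe_fusing=True, fuse_lora() checks the fused weights for NaN values
        # and raises a ValueError if the LoRA checkpoint is broken, instead of
        # silently fusing weights that would turn every generated image black.
        pipe.fuse_lora(lora_scale=1.0, safe_fusing=True)
    except ValueError:
        # The checkpoint is broken; discard the LoRA rather than generate with it.
        pipe.unload_lora_weights()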