Fix loading broken LoRAs that could give NaN (#5316)
* Fix fuse Lora

* improve a bit

* make style

* Update src/diffusers/models/lora.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* ciao C file

* ciao C file

* test & make style

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
commit ed2f956072 (parent a844065384), committed via GitHub
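For orientation before the hunks, here is a minimal standalone sketch of the guard this commit adds (the helper `fuse_linear_lora` and its tensors are illustrative only, not diffusers API): the fused weight is computed first, and the original weight is only overwritten if the result is NaN-free.

import torch

def fuse_linear_lora(weight, w_up, w_down, lora_scale=1.0, safe_fusing=False):
    # weight: (out, in); w_up: (out, rank); w_down: (rank, in)
    fused = weight.float() + lora_scale * torch.mm(w_up.float(), w_down.float())
    if safe_fusing and torch.isnan(fused).any().item():
        # Refuse to fold a broken LoRA into the base weight.
        raise ValueError("Encountered NaN values when trying to fuse LoRA weights.")
    return fused.to(weight.dtype)

# An inf-corrupted down projection (as in the new test below) turns into NaN
# during fusion, because +inf and -inf contributions are summed.
w = torch.randn(4, 4)
w_up = torch.tensor([[1.0, -1.0], [0.5, 2.0], [-2.0, 1.0], [3.0, -0.5]])
w_down = torch.full((2, 4), float("inf"))
try:
    fuse_linear_lora(w, w_up, w_down, safe_fusing=True)
except ValueError as err:
    print(err)

Without the check, the broken fused weights would be written into the model and every subsequent image would come out black.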
@@ -121,7 +121,7 @@ class PatchedLoraProjection(nn.Module):
         return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_linear_layer is None:
             return
 
@@ -135,6 +135,14 @@ class PatchedLoraProjection(nn.Module):
             w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
 
         fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                " LoRA weights will not be fused."
+            )
+
         self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -672,13 +680,14 @@ class UNet2DConditionLoadersMixin:
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self, lora_scale=1.0):
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         self.lora_scale = lora_scale
+        self._safe_fusing = safe_fusing
         self.apply(self._fuse_lora_apply)
 
     def _fuse_lora_apply(self, module):
         if hasattr(module, "_fuse_lora"):
-            module._fuse_lora(self.lora_scale)
+            module._fuse_lora(self.lora_scale, self._safe_fusing)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
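As an aside, the `apply`-based dispatch in this hunk can be illustrated with a toy sketch (the `ToyFusable` and `ToyLoader` classes are invented for illustration, not diffusers classes): the fusing options are stored on the owner, and `nn.Module.apply` visits every submodule, fusing only those that expose `_fuse_lora`.

import torch.nn as nn

class ToyFusable(nn.Linear):
    # Stand-in for a LoRA-compatible layer: anything exposing `_fuse_lora`.
    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
        print(f"fusing {self} with scale={lora_scale}, safe_fusing={safe_fusing}")

class ToyLoader(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(ToyFusable(4, 4), nn.ReLU(), ToyFusable(4, 4))

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
        # Store the options, then let `apply` walk the module tree.
        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(self._fuse_lora_apply)

    def _fuse_lora_apply(self, module):
        if hasattr(module, "_fuse_lora"):
            module._fuse_lora(self.lora_scale, self._safe_fusing)

ToyLoader().fuse_lora(lora_scale=0.8, safe_fusing=True)  # fuses both ToyFusable layers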
@@ -2086,7 +2095,13 @@ class LoraLoaderMixin:
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
-    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
+    def fuse_lora(
+        self,
+        fuse_unet: bool = True,
+        fuse_text_encoder: bool = True,
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+    ):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
 
@@ -2103,6 +2118,8 @@ class LoraLoaderMixin:
                 LoRA parameters then it won't have any effect.
             lora_scale (`float`, defaults to 1.0):
                 Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and, if values are NaN, not fusing them.
         """
         if fuse_unet or fuse_text_encoder:
             self.num_fused_loras += 1
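From the pipeline side, the new argument is used as in the hedged sketch below (the LoRA file path is a placeholder): with `safe_fusing=True`, a broken checkpoint raises a `ValueError` instead of silently writing NaN values into the model weights.

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/pytorch_lora_weights.safetensors")  # placeholder path

try:
    pipe.fuse_lora(lora_scale=1.0, safe_fusing=True)
except ValueError:
    # The checkpoint would produce NaN fused weights; skip fusing (or unload the
    # LoRA) rather than generating all-black images.
    print("Broken LoRA detected, not fusing.")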
@@ -2112,12 +2129,13 @@ class LoraLoaderMixin:
                 )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale)
+            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if self.use_peft_backend:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+                # TODO(Patrick, Younes): enable "safe" fusing
                 for module in text_encoder.modules():
                     if isinstance(module, BaseTunerLayer):
                         if lora_scale != 1.0:
@@ -2129,24 +2147,24 @@ class LoraLoaderMixin:
             if version.parse(__version__) > version.parse("0.23"):
                 deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
                 for _, attn_module in text_encoder_attn_modules(text_encoder):
                     if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                        attn_module.q_proj._fuse_lora(lora_scale)
-                        attn_module.k_proj._fuse_lora(lora_scale)
-                        attn_module.v_proj._fuse_lora(lora_scale)
-                        attn_module.out_proj._fuse_lora(lora_scale)
+                        attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
 
                 for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                     if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                        mlp_module.fc1._fuse_lora(lora_scale)
-                        mlp_module.fc2._fuse_lora(lora_scale)
+                        mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
+                        mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
 
     def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
         r"""
@@ -112,7 +112,7 @@ class LoRACompatibleConv(nn.Conv2d):
     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_layer is None:
             return
 
@@ -128,6 +128,14 @@ class LoRACompatibleConv(nn.Conv2d):
         fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
         fusion = fusion.reshape((w_orig.shape))
         fused_weight = w_orig + (lora_scale * fusion)
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                " LoRA weights will not be fused."
+            )
+
         self.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
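The flatten/mm/reshape in this hunk is the fusion rule for a conv LoRA; a small numerical check (shapes and names are illustrative) shows that composing the k x k down conv with the 1 x 1 up conv equals a single conv whose weight is that reshaped matrix product.

import torch
import torch.nn.functional as F

in_ch, rank, out_ch, k = 8, 4, 16, 3
x = torch.randn(1, in_ch, 32, 32)
w_orig = torch.randn(out_ch, in_ch, k, k)   # base conv weight
w_down = torch.randn(rank, in_ch, k, k)     # LoRA down: k x k conv, in -> rank
w_up = torch.randn(out_ch, rank, 1, 1)      # LoRA up: 1 x 1 conv, rank -> out
lora_scale = 0.7

# Running the two LoRA convs one after the other ...
lora_out = F.conv2d(F.conv2d(x, w_down, padding=1), w_up)

# ... matches a single conv with the fused delta computed as in the hunk above.
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)).reshape(w_orig.shape)
fused_out = F.conv2d(x, w_orig + lora_scale * fusion, padding=1)
base_out = F.conv2d(x, w_orig, padding=1)

assert torch.allclose(fused_out, base_out + lora_scale * lora_out, atol=1e-3, rtol=1e-3)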
@@ -182,7 +190,7 @@ class LoRACompatibleLinear(nn.Linear):
     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_layer is None:
             return
 
@@ -196,6 +204,14 @@ class LoRACompatibleLinear(nn.Linear):
             w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
 
         fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                " LoRA weights will not be fused."
+            )
+
         self.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
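The linear case admits the same sanity check (illustrative shapes): the `torch.bmm` expression above folds the LoRA branch into the base weight, so a layer using the fused weight reproduces the base output plus the scaled LoRA path.

import torch

out_f, rank, in_f = 16, 4, 8
w_orig = torch.randn(out_f, in_f)
w_up = torch.randn(out_f, rank)
w_down = torch.randn(rank, in_f)
x = torch.randn(2, in_f)
lora_scale = 0.5

# Fused weight computed exactly as in the hunk above.
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

fused_out = x @ fused_weight.t()
separate_out = x @ w_orig.t() + lora_scale * ((x @ w_down.t()) @ w_up.t())
assert torch.allclose(fused_out, separate_out, atol=1e-5)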
@@ -1028,6 +1028,47 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
 
         sd_pipe.unload_lora_weights()
 
+    def test_lora_fuse_nan(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        # corrupt one LoRA weight with `inf` values
+        with torch.no_grad():
+            sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
+                "inf"
+            )
+
+        # with `safe_fusing=True` we should see an error
+        with self.assertRaises(ValueError):
+            sd_pipe.fuse_lora(safe_fusing=True)
+
+        # without it, we should not see an error, but every image will be black
+        sd_pipe.fuse_lora(safe_fusing=False)
+
+        out = sd_pipe("test", num_inference_steps=2, output_type="np").images
+
+        assert np.isnan(out).all()
+
     def test_lora_fusion(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**pipeline_components)