mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix missing **kwargs in lora_pipeline.py (#11011)
* Update lora_pipeline.py * Apply style fixes * fix-copies --------- Co-authored-by: hlky <hlky@hlky.ac> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -452,7 +452,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
|
||||
@@ -473,7 +477,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -892,7 +896,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
|
||||
@@ -913,7 +921,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -1291,7 +1299,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
@@ -1313,7 +1325,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -1829,7 +1841,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
@@ -1850,7 +1866,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
|
||||
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
|
||||
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
|
||||
def unload_lora_weights(self, reset_to_overwritten_params=False):
|
||||
@@ -2549,7 +2565,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
@@ -2567,7 +2587,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -2853,7 +2873,11 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -2872,7 +2896,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -3158,7 +3182,11 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -3177,7 +3205,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -3463,7 +3491,11 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -3482,7 +3514,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -3771,7 +3803,11 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -3790,7 +3826,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -4080,7 +4116,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
||||
@@ -4099,7 +4139,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -4386,7 +4426,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -4405,7 +4449,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -4691,7 +4735,11 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -4710,7 +4758,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
|
||||
Reference in New Issue
Block a user