1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix missing **kwargs in lora_pipeline.py (#11011)

* Update lora_pipeline.py

* Apply style fixes

* fix-copies

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
CyberVy
2025-03-12 01:34:27 +08:00
committed by GitHub
parent 36d0553af2
commit d87ce2cefc

View File

@@ -452,7 +452,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
@@ -473,7 +477,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@@ -892,7 +896,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
@@ -913,7 +921,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class SD3LoraLoaderMixin(LoraBaseMixin):
@@ -1291,7 +1299,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
@@ -1313,7 +1325,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class FluxLoraLoaderMixin(LoraBaseMixin):
@@ -1829,7 +1841,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
)
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
@@ -1850,7 +1866,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
def unload_lora_weights(self, reset_to_overwritten_params=False):
@@ -2549,7 +2565,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -2567,7 +2587,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class Mochi1LoraLoaderMixin(LoraBaseMixin):
@@ -2853,7 +2873,11 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -2872,7 +2896,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3158,7 +3182,11 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3177,7 +3205,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class SanaLoraLoaderMixin(LoraBaseMixin):
@@ -3463,7 +3491,11 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3482,7 +3514,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3771,7 +3803,11 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3790,7 +3826,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class Lumina2LoraLoaderMixin(LoraBaseMixin):
@@ -4080,7 +4116,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
@@ -4099,7 +4139,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class WanLoraLoaderMixin(LoraBaseMixin):
@@ -4386,7 +4426,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -4405,7 +4449,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class CogView4LoraLoaderMixin(LoraBaseMixin):
@@ -4691,7 +4735,11 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -4710,7 +4758,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):