Copy lora functions to XLPipelines (#4512)
* add load_lora_weights and save_lora_weights to StableDiffusionXLImg2ImgPipeline
* add load_lora_weights and save_lora_weights to StableDiffusionXLInpaintPipeline
* apply black format
* apply black format
* add copy statement
* fix statements
* fix statements
* fix statements
* run `make fix-copies`
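For context: after this change, both XL pipelines expose the same LoRA interface as StableDiffusionXLPipeline. A minimal usage sketch, not part of the commit (the LoRA repo id below is a hypothetical placeholder):

    import torch
    from diffusers import StableDiffusionXLImg2ImgPipeline

    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
    )
    # Loads LoRA weights into the UNet and, if present, both text encoders.
    pipe.load_lora_weights("some-user/sdxl-lora")  # hypothetical repo id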
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -1012,3 +1013,76 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
             return (image,)
 
         return StableDiffusionXLPipelineOutput(images=image)
+
+    # Override to properly handle the loading and unloading of the additional text encoder.
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        # We could have accessed the unet config from `lora_state_dict()` too. We pass
+        # it here explicitly to be able to tell that it's coming from an SDXL
+        # pipeline.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+            )
+
+    @classmethod
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
+    def save_lora_weights(
+        self,
+        save_directory: Union[str, os.PathLike],
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = False,
+    ):
+        state_dict = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
+        state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+        self.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
+    def _remove_text_encoder_monkey_patch(self):
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
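The nested pack_weights helper namespaces every key under its component prefix before the combined state dict is written. A standalone sketch of the same behavior (the layer key below is made up for illustration):

    import torch

    def pack_weights(layers, prefix):
        # Same logic as the helper in save_lora_weights: accept an nn.Module
        # or a plain dict and prefix every key with the component name.
        layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        return {f"{prefix}.{name}": param for name, param in layers_weights.items()}

    lora = {"to_q.lora.down.weight": torch.zeros(4, 8)}  # illustrative key
    print(list(pack_weights(lora, "unet")))
    # ['unet.to_q.lora.down.weight']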
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -1292,3 +1293,76 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
             return (image,)
 
         return StableDiffusionXLPipelineOutput(images=image)
+
+    # Override to properly handle the loading and unloading of the additional text encoder.
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        # We could have accessed the unet config from `lora_state_dict()` too. We pass
+        # it here explicitly to be able to tell that it's coming from an SDXL
+        # pipeline.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+            )
+
+    @classmethod
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
+    def save_lora_weights(
+        self,
+        save_directory: Union[str, os.PathLike],
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = False,
+    ):
+        state_dict = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
+        state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+        self.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
+    def _remove_text_encoder_monkey_patch(self):
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
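Worth noting: in load_lora_weights, the trailing dot in the "text_encoder." substring test is what keeps the first encoder's keys from also matching "text_encoder_2." keys. A standalone check of that filter, with made-up keys:

    state_dict = {
        "unet.down_blocks.0.lora.down.weight": 0,
        "text_encoder.layers.0.lora.down.weight": 1,
        "text_encoder_2.layers.0.lora.down.weight": 2,
    }
    te_1 = {k: v for k, v in state_dict.items() if "text_encoder." in k}
    te_2 = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
    print(sorted(te_1))  # ['text_encoder.layers.0.lora.down.weight']
    print(sorted(te_2))  # ['text_encoder_2.layers.0.lora.down.weight']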