From efa9b0a19915bca9a7587a02b82c57634f9dccbd Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 3 Sep 2024 06:01:11 +0200
Subject: [PATCH] make fix-copies

---
 src/diffusers/loaders/lora_pipeline.py       | 17 ++++++++++---
 .../pipeline_cogvideox_video2video.py        | 25 ++++++++++++++++++-
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 5ce899a12a..46da37eed6 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2300,26 +2300,30 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
 
         <Tip warning={true}>
 
-        We support loading A1111 formatted LoRA checkpoints in a limited capacity. This function is experimental and
-        might change in the future.
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
 
         </Tip>
 
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 Can be either:
+
                     - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                       the Hub.
                     - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                       with [`ModelMixin.save_pretrained`].
                     - A [torch state
                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
+
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -2334,6 +2338,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -2525,8 +2530,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The key should be prefixed with an
                 additional `text_encoder` to distinguish between unet lora layers.
             network_alphas (`Dict[str, float]`):
-                See `LoRALinearLayer` for more details.
-            text_encoder (`T5EncoderModel`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            text_encoder (`CLIPTextModel`):
                 The text encoder model to load the LoRA layers into.
             prefix (`str`):
                 Expected prefix of the `text_encoder` in the `state_dict`.
@@ -2705,7 +2712,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
                 Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
             adapter_names (`List[str]`, *optional*):
                 Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
         Example:
+
         ```py
         from diffusers import DiffusionPipeline
         import torch

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 16686d1ab7..7e4310cae8 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -22,13 +22,17 @@ from PIL import Image
 from transformers import T5EncoderModel, T5Tokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import CogVideoXLoraLoaderMixin
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from ...utils import (
+    USE_PEFT_BACKEND,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
@@ -161,7 +165,7 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")
 
 
-class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
+class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     r"""
     Pipeline for video-to-video generation using CogVideoX.
 
@@ -270,6 +274,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
         max_sequence_length: int = 226,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -296,9 +301,20 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
                 torch device
             dtype: (`torch.dtype`, *optional*):
                 torch dtype
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
         device = device or self._execution_device
 
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, CogVideoXLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         prompt = [prompt] if isinstance(prompt, str) else prompt
         if prompt is not None:
             batch_size = len(prompt)
@@ -338,6 +354,11 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
             dtype=dtype,
         )
 
+        if self.text_encoder is not None:
+            if isinstance(self, CogVideoXLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, negative_prompt_embeds
 
     def prepare_latents(
@@ -572,6 +593,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 226,
+        lora_scale: Optional[float] = None,
     ) -> Union[CogVideoXPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -694,6 +716,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
             device=device,
+            lora_scale=lora_scale,
         )
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
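
The `fuse_lora` example in the first file's last hunk is cut off at the hunk boundary. A rough, self-contained sketch of the usage pattern it points at, with a placeholder checkpoint id and LoRA path that are not text from the patch:

```py
import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint id; any CogVideoX pipeline checkpoint works the same way.
pipe = DiffusionPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Placeholder LoRA checkpoint path, for illustration only.
pipe.load_lora_weights("path/to/cogvideox-lora")

# Merge the LoRA weights into the base weights so inference skips the
# adapter indirection; lora_scale sets the merge strength.
pipe.fuse_lora(lora_scale=0.7)
```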
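With `CogVideoXVideoToVideoPipeline` now inheriting `CogVideoXLoraLoaderMixin` and `__call__` forwarding `lora_scale` to `encode_prompt`, the new plumbing can be exercised end to end roughly as below; the input clip, prompt, and LoRA path are assumptions for illustration:

```py
import torch
from diffusers import CogVideoXVideoToVideoPipeline
from diffusers.utils import export_to_video, load_video

pipe = CogVideoXVideoToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder LoRA checkpoint, assumed to exist locally.
pipe.load_lora_weights("path/to/cogvideox-lora")

video = load_video("input.mp4")  # assumed input clip
frames = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    video=video,
    # Forwarded to encode_prompt(): scales the text encoder's LoRA layers
    # for the encoding pass; they are unscaled again afterwards.
    lora_scale=0.9,
).frames[0]
export_to_video(frames, "output.mp4", fps=8)
```

Note that `lora_scale` here only touches text encoder LoRA layers during prompt encoding; the transformer's LoRA strength is still controlled through the adapter APIs (e.g. `set_adapters`).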