diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index a1576be979..9726944ee0 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -15,7 +15,7 @@
 
 import inspect
 import math
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import PIL
 import torch
@@ -23,13 +23,17 @@ from transformers import T5EncoderModel, T5Tokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
+from ...loaders import CogVideoXLoraLoaderMixin
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from ...utils import (
+    USE_PEFT_BACKEND,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
@@ -265,6 +269,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
         max_sequence_length: int = 226,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -291,9 +296,20 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
                 torch device
             dtype: (`torch.dtype`, *optional*):
                 torch dtype
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
         device = device or self._execution_device
 
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, CogVideoXLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         prompt = [prompt] if isinstance(prompt, str) else prompt
         if prompt is not None:
             batch_size = len(prompt)
@@ -333,6 +349,11 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
                 dtype=dtype,
             )
 
+        if self.text_encoder is not None:
+            if isinstance(self, CogVideoXLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, negative_prompt_embeds
 
     def prepare_latents(
@@ -547,6 +568,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
     @property
     def interrupt(self):
         return self._interrupt
@@ -573,6 +598,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: str = "pil",
         return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
@@ -636,6 +662,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -681,6 +711,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
         self._interrupt = False
 
         # 2. Default call parameters
@@ -699,6 +730,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
+        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
@@ -708,6 +740,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
             device=device,
+            lora_scale=lora_scale,
         )
         if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)