
add LoRA support to img2vid pipeline to fix tests

commit 969a9608c3
parent ec8d483e72
Author: Aryan
Date:   2024-09-16 15:13:59 +02:00
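As a quick orientation for reviewers, the sketch below shows the user-facing path this commit wires up: a "scale" entry in attention_kwargs is read in __call__ and forwarded to encode_prompt as the text-encoder LoRA scale. The checkpoint id is real, but the LoRA path and input image are placeholders, and the sketch assumes the pipeline also inherits CogVideoXLoraLoaderMixin (which the isinstance checks in the diff rely on); treat it as an illustration, not part of the commit.

import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

# load_lora_weights comes from CogVideoXLoraLoaderMixin (assumed here).
pipe.load_lora_weights("path/to/lora")  # placeholder path

image = load_image("input.png")  # placeholder conditioning frame
video = pipe(
    image=image,
    prompt="a panning shot of a quiet lake",
    # New in this commit: "scale" is read out of attention_kwargs in
    # __call__ and passed to encode_prompt as lora_scale.
    attention_kwargs={"scale": 0.9},
).frames[0]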

src/diffusers/pipelines/cogvideox/pipeline_cogvideox_image2video.py

@@ -15,7 +15,7 @@
 import inspect
 import math
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import PIL
 import torch
@@ -23,13 +23,17 @@ from transformers import T5EncoderModel, T5Tokenizer
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
+from ...loaders import CogVideoXLoraLoaderMixin
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from ...utils import (
+    USE_PEFT_BACKEND,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
@@ -265,6 +269,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
         max_sequence_length: int = 226,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -291,9 +296,20 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
                 torch device
             dtype: (`torch.dtype`, *optional*):
                 torch dtype
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
         device = device or self._execution_device
 
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, CogVideoXLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         prompt = [prompt] if isinstance(prompt, str) else prompt
         if prompt is not None:
             batch_size = len(prompt)
@@ -333,6 +349,11 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
                 dtype=dtype,
             )
 
+        if self.text_encoder is not None:
+            if isinstance(self, CogVideoXLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, negative_prompt_embeds
 
     def prepare_latents(
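The two hunks above bracket prompt encoding with the usual diffusers scale/unscale pattern: the text encoder's LoRA layers are scaled before encoding and restored afterwards, so a per-call scale never leaks into subsequent calls. Below is a self-contained toy illustration of that symmetry; ToyLoraLayer and encode_with_lora_scale are made up for this sketch and merely stand in for peft-backed layers and scale_lora_layers/unscale_lora_layers.

# Toy stand-in for the scale-then-unscale bracket in encode_prompt.
class ToyLoraLayer:
    def __init__(self):
        self.scaling = 1.0  # default LoRA strength

    def scale_layer(self, scale: float):
        self.scaling *= scale

    def unscale_layer(self, scale: float):
        self.scaling /= scale


def encode_with_lora_scale(layers, lora_scale):
    # stand-in for scale_lora_layers(text_encoder, lora_scale)
    for layer in layers:
        layer.scale_layer(lora_scale)
    try:
        # stand-in for the actual prompt encoding
        return [layer.scaling for layer in layers]
    finally:
        # stand-in for unscale_lora_layers(text_encoder, lora_scale)
        for layer in layers:
            layer.unscale_layer(lora_scale)


layers = [ToyLoraLayer(), ToyLoraLayer()]
assert encode_with_lora_scale(layers, 0.5) == [0.5, 0.5]
assert all(layer.scaling == 1.0 for layer in layers)  # scale restored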
@@ -547,6 +568,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
     @property
     def interrupt(self):
         return self._interrupt
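The new property mirrors the pipeline's existing convention (compare num_timesteps and interrupt above): __call__ writes per-call state to a private attribute once, and downstream helpers read it through a read-only property. A stripped-down sketch of that pattern, using a made-up ToyPipeline rather than the real class:

from typing import Any, Dict, Optional


class ToyPipeline:
    def __init__(self):
        self._attention_kwargs: Optional[Dict[str, Any]] = None

    @property
    def attention_kwargs(self):
        # read-only view of per-call state, as in the diff above
        return self._attention_kwargs

    def __call__(self, attention_kwargs: Optional[Dict[str, Any]] = None):
        self._attention_kwargs = attention_kwargs  # written once per call
        # helper code reads self.attention_kwargs instead of a threaded argument
        return self.attention_kwargs


assert ToyPipeline()(attention_kwargs={"scale": 0.5}) == {"scale": 0.5}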
@@ -573,6 +598,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: str = "pil",
         return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
@@ -636,6 +662,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
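Note that attention_kwargs does double duty: besides being passed along to the attention processors per the docstring, its "scale" key is read in __call__ (see the hunk near the end of this diff) and becomes the lora_scale for encode_prompt. A minimal sketch of that extraction convention, with a made-up helper name:

from typing import Any, Dict, Optional


def extract_lora_scale(attention_kwargs: Optional[Dict[str, Any]]) -> Optional[float]:
    # mirrors: lora_scale = self.attention_kwargs.get("scale", None)
    #          if self.attention_kwargs is not None else None
    return attention_kwargs.get("scale", None) if attention_kwargs is not None else None


assert extract_lora_scale(None) is None           # no kwargs passed
assert extract_lora_scale({}) is None             # kwargs without "scale"
assert extract_lora_scale({"scale": 0.8}) == 0.8  # LoRA scale requested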
@@ -681,6 +711,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
         self._interrupt = False
 
         # 2. Default call parameters
@@ -699,6 +730,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
+        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
@@ -708,6 +740,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
             device=device,
+            lora_scale=lora_scale,
         )
 
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)