
make fix-copies

Aryan
2024-09-03 06:01:11 +02:00
parent 9a95d8de56
commit efa9b0a199
2 changed files with 37 additions and 5 deletions


@@ -2300,26 +2300,30 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         <Tip warning={true}>
 
-        We support loading A1111 formatted LoRA checkpoints in a limited capacity. This function is experimental and
-        might change in the future.
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
 
         </Tip>
 
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 Can be either:
 
                     - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                       the Hub.
                     - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                       with [`ModelMixin.save_pretrained`].
                     - A [torch state
                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
 
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -2334,6 +2338,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
 
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
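
For orientation, a minimal usage sketch of the `load_lora_weights` API this docstring documents; the LoRA repository id and weight filename are hypothetical placeholders, and the base checkpoint is assumed to be one of the public CogVideoX models:

```py
import torch

from diffusers import CogVideoXPipeline

# Hypothetical LoRA repo id and weight filename; only the API shape is the point here.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.load_lora_weights("some-user/cogvideox-lora", weight_name="pytorch_lora_weights.safetensors")

# Per the docstring above, a plain torch state dict is also accepted in place of a repo id.
```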
@@ -2525,8 +2530,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The key should be prefixed with an
                 additional `text_encoder` to distinguish between unet lora layers.
             network_alphas (`Dict[str, float]`):
-                See `LoRALinearLayer` for more details.
-            text_encoder (`T5EncoderModel`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            text_encoder (`CLIPTextModel`):
                 The text encoder model to load the LoRA layers into.
             prefix (`str`):
                 Expected prefix of the `text_encoder` in the `state_dict`.
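
To make the new `network_alphas` description concrete: in kohya-ss/PEFT-style LoRA, the alpha is divided by the adapter rank to scale the low-rank update, so a smaller alpha damps the LoRA contribution. A self-contained sketch with illustrative values (none of them taken from this commit):

```py
import torch

rank, network_alpha = 4, 1.0                   # illustrative values
down = torch.nn.Linear(64, rank, bias=False)   # LoRA "A" matrix
up = torch.nn.Linear(rank, 64, bias=False)     # LoRA "B" matrix
x = torch.randn(1, 64)

scaling = network_alpha / rank        # the per-layer role network_alphas plays
lora_update = up(down(x)) * scaling   # added on top of the frozen base layer's output
```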
@@ -2705,7 +2712,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
                 Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
             adapter_names (`List[str]`, *optional*):
                 Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
 
         Example:
 
         ```py
         from diffusers import DiffusionPipeline
         import torch
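
The `Example:` block above is cut off at the hunk boundary. As a hedged sketch of the pattern it introduces: fusing merges the loaded LoRA deltas into the base weights so inference skips the extra LoRA matmuls (checkpoint names below are placeholders):

```py
import torch

from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.load_lora_weights("some-user/cogvideox-lora")  # placeholder repo id
pipe.fuse_lora(lora_scale=1.0)  # scale the LoRA weights, then fold them into the base model
```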


@@ -22,13 +22,17 @@ from PIL import Image
 from transformers import T5EncoderModel, T5Tokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import CogVideoXLoraLoaderMixin
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from ...utils import (
+    USE_PEFT_BACKEND,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
@@ -161,7 +165,7 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
 
 
-class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
+class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     r"""
     Pipeline for video-to-video generation using CogVideoX.
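
Adding the mixin as a second base class is what exposes the loader API on this pipeline and makes the `isinstance(self, CogVideoXLoraLoaderMixin)` checks in the hunks below succeed. A small sketch (model and LoRA ids are placeholders):

```py
from diffusers import CogVideoXVideoToVideoPipeline
from diffusers.loaders import CogVideoXLoraLoaderMixin

pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b")
assert isinstance(pipe, CogVideoXLoraLoaderMixin)   # True after this change
pipe.load_lora_weights("some-user/cogvideox-lora")  # method inherited from the mixin
```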
@@ -270,6 +274,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
         max_sequence_length: int = 226,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -296,9 +301,20 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
                 torch device
             dtype: (`torch.dtype`, *optional*):
                 torch dtype
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
         device = device or self._execution_device
 
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, CogVideoXLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         prompt = [prompt] if isinstance(prompt, str) else prompt
         if prompt is not None:
             batch_size = len(prompt)
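
The pattern here is scale-before-encode: `scale_lora_layers` multiplies the scaling of every PEFT LoRA layer in the text encoder by `lora_scale` before the forward pass, and `unscale_lora_layers` (next hunk) divides it back out afterwards. A conceptual stand-in for the two helpers, not the actual `diffusers.utils` implementation:

```py
import torch

def scale_lora_layers_sketch(model: torch.nn.Module, weight: float) -> None:
    # Multiply each LoRA adapter's scaling in place; `scaling` here stands in
    # for the attribute PEFT tuner layers carry.
    for module in model.modules():
        if hasattr(module, "scaling"):
            for adapter in module.scaling:
                module.scaling[adapter] *= weight

def unscale_lora_layers_sketch(model: torch.nn.Module, weight: float) -> None:
    # Divide the scaling back out so subsequent calls see the original strength.
    if weight == 0:
        return
    for module in model.modules():
        if hasattr(module, "scaling"):
            for adapter in module.scaling:
                module.scaling[adapter] /= weight
```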
@@ -338,6 +354,11 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
             dtype=dtype,
         )
 
+        if self.text_encoder is not None:
+            if isinstance(self, CogVideoXLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, negative_prompt_embeds
 
     def prepare_latents(
@@ -572,6 +593,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 226,
+        lora_scale: Optional[float] = None,
     ) -> Union[CogVideoXPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
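
End to end, the new keyword threads from `__call__` into `encode_prompt`. A hedged usage sketch; the model id, LoRA repo, and input path are placeholders, and `lora_scale` is the argument added in this hunk:

```py
import torch

from diffusers import CogVideoXVideoToVideoPipeline
from diffusers.utils import load_video

pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("some-user/cogvideox-lora")  # placeholder repo id
video = load_video("input.mp4")                     # placeholder input clip

frames = pipe(
    video=video,
    prompt="a panda playing a guitar",
    lora_scale=0.8,  # forwarded to encode_prompt, which scales the text-encoder LoRA layers
).frames[0]
```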
@@ -694,6 +716,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
             device=device,
+            lora_scale=lora_scale,
         )
 
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
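
For completeness, the concatenation in the last context line is the standard classifier-free guidance trick: one batched forward pass yields both the unconditional and conditional predictions, which are recombined afterwards. A minimal sketch with dummy tensors (shapes are illustrative, not taken from the pipeline):

```py
import torch

negative_prompt_embeds = torch.randn(1, 226, 4096)  # unconditional embeddings
prompt_embeds = torch.randn(1, 226, 4096)           # conditional embeddings
guidance_scale = 6.0

# One batch of 2 instead of two separate forward passes through the transformer.
embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

noise_pred = torch.randn(2, 13, 16, 60, 90)         # stand-in for the transformer output
noise_uncond, noise_text = noise_pred.chunk(2)
guided = noise_uncond + guidance_scale * (noise_text - noise_uncond)
```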