diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
new file mode 100644
index 0000000000..9713809182
--- /dev/null
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -0,0 +1,1483 @@
+# Copyright 2024 The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import itertools
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import T5EncoderModel, T5Tokenizer
+
+import diffusers
+from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")
+
+ # Model information
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ # Dataset information
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that ๐ค Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_root",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+ parser.add_argument(
+ "--video_column",
+ type=str,
+ default="video",
+ help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ # Validation
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
+ )
+ parser.add_argument(
+ "--validation_prompt_separator",
+ type=str,
+ default=":::",
+ help="String that separates multiple validation prompts",
+ )
+ parser.add_argument(
+ "--num_validation_videos",
+ type=int,
+ default=1,
+ help="Number of videos that should be generated during validation per `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`."
+ ),
+ )
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=6,
+ help="The guidance scale to use while sampling validation videos.",
+ )
+ parser.add_argument(
+ "--use_dynamic_cfg",
+ action="store_true",
+ default=False,
+ help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
+ )
+
+ # Training information
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="cogvideox-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--height",
+ type=int,
+ default=480,
+ help="All input videos are resized to this height.",
+ )
+ parser.add_argument(
+ "--width",
+ type=int,
+ default=720,
+ help="All input videos are resized to this width.",
+ )
+ parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
+ parser.add_argument(
+ "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
+ )
+ parser.add_argument(
+ "--skip_frames_start",
+ type=int,
+ default=0,
+ help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
+ )
+ parser.add_argument(
+ "--skip_frames_end",
+ type=int,
+ default=0,
+ help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip videos horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+
+ # Optimizer
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW"]'),
+ )
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+
+ # Other information
+ parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default=None,
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+
+ return parser.parse_args()
+
+
+class VideoDataset(Dataset):
+ def __init__(
+ self,
+ instance_data_root: str,
+ dataset_name: Optional[str] = None,
+ dataset_config_name: Optional[str] = None,
+ caption_column: str = "text",
+ video_column: str = "video",
+ height: int = 480,
+ width: int = 720,
+ fps: int = 8,
+ max_num_frames: int = 49,
+ skip_frames_start: int = 0,
+ skip_frames_end: int = 0,
+ cache_dir: Optional[str] = None,
+ ) -> None:
+ super().__init__()
+
+ self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
+ self.dataset_name = dataset_name
+ self.dataset_config_name = dataset_config_name
+ self.caption_column = caption_column
+ self.video_column = video_column
+ self.height = height
+ self.width = width
+ self.fps = fps
+ self.max_num_frames = max_num_frames
+ self.skip_frames_start = skip_frames_start
+ self.skip_frames_end = skip_frames_end
+ self.cache_dir = cache_dir
+
+ if dataset_name is not None:
+ self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
+ else:
+ self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
+
+ self.num_instance_videos = len(self.instance_video_paths)
+ if self.num_instance_videos != len(self.instance_prompts):
+ raise ValueError(
+ f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
+ )
+
+ self.instance_videos = self._preprocess_data()
+
+ def __len__(self):
+ return self.num_instance_videos
+
+ def __getitem__(self, index):
+ return {
+ "instance_prompt": self.instance_prompts[index],
+ "instance_video": self.instance_videos[index],
+ }
+
+ def _load_dataset_from_hub(self):
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_root instead."
+ )
+
+ # Downloading and loading a dataset from the hub. See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ self.dataset_name,
+ self.dataset_config_name,
+ cache_dir=self.cache_dir,
+ )
+ column_names = dataset["train"].column_names
+
+ if self.video_column is None:
+ video_column = column_names[0]
+ logger.info(f"`video_column` defaulting to {video_column}")
+ else:
+ video_column = self.video_column
+ if video_column not in column_names:
+ raise ValueError(
+ f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if self.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"`caption_column` defaulting to {caption_column}")
+ else:
+ caption_column = self.caption_column
+ if self.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{self.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ instance_prompts = dataset["train"][caption_column]
+ instance_videos = dataset["train"][video_column]
+
+ return instance_prompts, instance_videos
+
+ def _load_dataset_from_local_path(self):
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance videos root folder does not exist")
+
+ prompt_path = self.instance_data_root.joinpath(self.caption_column)
+ video_path = self.instance_data_root.joinpath(self.video_column)
+
+ if not prompt_path.exists() or not prompt_path.is_file():
+ raise ValueError(
+ "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts."
+ )
+ if not video_path.exists() or not video_path.is_file():
+ raise ValueError(
+ "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory."
+ )
+
+ with open(prompt_path, "r", encoding="utf-8") as file:
+ instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
+ with open(video_path, "r", encoding="utf-8") as file:
+ instance_videos = [
+ self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0
+ ]
+
+ if any(not path.is_file() for path in instance_videos):
+ raise ValueError(
+ "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found atleast one path that is not a valid file."
+ )
+
+ return instance_prompts, instance_videos
+
+ def _preprocess_data(self):
+ import decord
+
+ videos = []
+
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ for filename in self.instance_video_paths:
+ video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+ video_num_frames = len(video_reader)
+
+ start_frame = min(self.skip_frames_start, video_num_frames)
+ end_frame = max(0, video_num_frames - self.skip_frames_end)
+ if end_frame <= start_frame:
+ frames_numpy = video_reader.get_batch([start_frame]).asnumpy()
+ elif end_frame - start_frame <= self.max_num_frames:
+ frames_numpy = video_reader.get_batch(list(range(start_frame, end_frame))).asnumpy()
+ else:
+ indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
+ frames_numpy = video_reader.get_batch(indices).asnumpy()
+
+ # Just to ensure that we don't go over the limit
+ frames_numpy = frames_numpy[: self.max_num_frames]
+ selected_num_frames = frames_numpy.shape[0]
+
+ # Keep only the first (4k + 1) frames, as this is the number of frames required by the VAE
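+ # e.g. 50 sampled frames -> remainder = (3 + 50 % 4) % 4 = 1 -> keep the first 49 frames (49 = 4 * 12 + 1)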
+ remainder = (3 + (selected_num_frames % 4)) % 4
+ if remainder != 0:
+ frames_numpy = frames_numpy[:-remainder]
+ selected_num_frames = frames_numpy.shape[0]
+
+ assert (selected_num_frames - 1) % 4 == 0
+
+ # Training transforms
+ frames_tensor = torch.stack([train_transforms(frame) for frame in frames_numpy], dim=0)
+ videos.append(frames_tensor) # [F, C, H, W]
+
+ return videos
+
+
+def save_model_card(
+ repo_id: str,
+ videos=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ validation_prompt=None,
+ repo_folder=None,
+ fps=8,
+):
+ widget_dict = []
+ if videos is not None:
+ for i, video in enumerate(videos):
+ export_to_video(video, os.path.join(repo_folder, f"video_{i}.mp4"), fps=fps)
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}}
+ )
+
+ model_description = f"""
+# CogVideoX LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} LoRA weights for {base_model}.
+
+The weights were trained using the [CogVideoX Diffusers trainer](TODO).
+
+Was LoRA for the text encoder enabled? {train_text_encoder}.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import CogVideoXPipeline
+import torch
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
+pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors")
+video = pipe("{validation_prompt}").frames[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=validation_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-video",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "cogvideox",
+ "cogvideox-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipe,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation: bool = False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args["prompt"]}."
+ )
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipe.scheduler.config:
+ variance_type = pipe.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
+ pipe = pipe.to(accelerator.device)
+ # pipe.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+ videos = []
+ with torch.cuda.amp.autocast():
+ for _ in range(args.num_validation_videos):
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
+ videos.append(video)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Video(video, caption=f"{i}: {args.validation_prompt}") for i, video in enumerate(videos)
+ ]
+ }
+ )
+
+ del pipe
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return videos
+
+
+def collate_fn(examples):
+ videos = [example["instance_video"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ videos = torch.stack(videos)
+ videos = videos.to(memory_format=torch.contiguous_format).float()
+
+ return {
+ "videos": videos,
+ "prompts": prompts,
+ }
+
+
+def _get_t5_prompt_embeds(
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
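+ # e.g. [batch_size, seq_len, dim] -> [batch_size * num_videos_per_prompt, seq_len, dim]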
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def encode_prompt(
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt_embeds = _get_t5_prompt_embeds(
+ tokenizer,
+ text_encoder,
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ text_input_ids=text_input_ids,
+ )
+ return prompt_embeds
+
+
+def prepare_rotary_positional_embeddings(
+ height: int,
+ width: int,
+ num_frames: int,
+ vae_scale_factor_spatial: int = 8,
+ patch_size: int = 2,
+ attention_head_dim: int = 64,
+ device: Optional[torch.device] = None,
+ base_height: int = 480,
+ base_width: int = 720,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (vae_scale_factor_spatial * patch_size)
+ grid_width = width // (vae_scale_factor_spatial * patch_size)
+ base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
+ base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
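+ # e.g. a 480x720 video with a VAE spatial downscale factor of 8 and patch size 2 yields a 30x45 grid of latent patches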
+
+ grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ import decord
+
+ decord.bridge.set_bridge("torch")
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Prepare models and scheduler
+ tokenizer = T5Tokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ text_encoder = T5EncoderModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ vae = AutoencoderKLCogVideoX.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+
+ scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=weight_dtype)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
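+ # LoRA learns a low-rank update (lora_alpha / r) * B @ A on top of each frozen projection; with lora_alpha == r below, the scaling factor is 1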
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ text_encoder_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ elif isinstance(model, type(unwrap_model(text_encoder))):
+ text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ CogVideoXPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+ text_encoder_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder))):
+ text_encoder_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ if args.train_text_encoder:
+ models.extend([text_encoder])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
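+ # only the LoRA adapter parameters injected above have requires_grad=True, so this collects just the low-rank matrices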
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+ if args.train_text_encoder:
+ text_encoder_lora_parameters = list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # use a different learning rate for the text encoder and the transformer
+ text_encoder_parameters_with_lr = {
+ "params": text_encoder_lora_parameters,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ transformer_parameters_with_lr,
+ text_encoder_parameters_with_lr,
+ ]
+ else:
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if args.optimizer.lower() not in ["adamw"]:
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include ['AdamW']. "
+ "Defaulting to 'AdamW'."
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoader
+ train_dataset = VideoDataset(
+ instance_data_root=args.instance_data_root,
+ dataset_name=args.dataset_name,
+ dataset_config_name=args.dataset_config_name,
+ caption_column=args.caption_column,
+ video_column=args.video_column,
+ height=args.height,
+ width=args.width,
+ fps=args.fps,
+ max_num_frames=args.max_num_frames,
+ skip_frames_start=args.skip_frames_start,
+ skip_frames_end=args.skip_frames_end,
+ cache_dir=args.cache_dir,
+ )
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ if not args.train_text_encoder:
+
+ def compute_text_embeddings(prompt):
+ with torch.no_grad():
+ prompt_embeds = encode_prompt(
+ tokenizer,
+ text_encoder,
+ prompt,
+ num_videos_per_prompt=1,
+ device=accelerator.device,
+ dtype=weight_dtype,
+ )
+ return prompt_embeds
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
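+ # e.g. 100 batches per epoch with gradient_accumulation_steps=4 -> ceil(100 / 4) = 25 optimizer updates per epoch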
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ (
+ transformer,
+ text_encoder,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ transformer,
+ text_encoder,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ )
+ else:
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = args.tracker_name or "cogvideox-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
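+ # e.g. train_batch_size=4 on 2 processes with gradient_accumulation_steps=2 -> effective batch size of 16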
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if not args.resume_from_checkpoint:
+ initial_global_step = 0
+ else:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+ vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
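+ # e.g. with the default CogVideoX VAE config (4 entries in block_out_channels), this gives a spatial downscale factor of 2**3 = 8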
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ # set top parameter requires_grad = True so that gradient checkpointing works
+ accelerator.unwrap_model(text_encoder).text_model.embeddings.requires_grad_(True)
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ if args.train_text_encoder:
+ models_to_accumulate.extend([text_encoder])
+
+ with accelerator.accumulate(models_to_accumulate):
+ videos = batch["videos"].to(dtype=vae.dtype)
+ prompts = batch["prompts"]
+
+ # encode prompts
+ if not args.train_text_encoder:
+ prompt_embeds = compute_text_embeddings(prompts)
+ else:
+ text_inputs = tokenizer(
+ prompts,
+ padding="max_length",
+ max_length=226,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = encode_prompt(
+ tokenizer=None,
+ text_encoder=text_encoder,
+ prompt=None,
+ num_videos_per_prompt=1,
+ device=accelerator.device,
+ dtype=weight_dtype,
+ text_input_ids=text_input_ids,
+ )
+
+ # Convert videos to latents
+ videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
+ model_input = vae.encode(videos).latent_dist.sample() * vae.config.scaling_factor
+ model_input = model_input.permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
+
+ # Sample noise that will be added to the latents
+ noise = torch.randn_like(model_input)
+ batch_size, num_frames, num_channels, height, width = model_input.shape
+
+ # Sample a random timestep for each video
+ timesteps = torch.randint(
+ 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Prepare rotary embeds
+ image_rotary_emb = (
+ prepare_rotary_positional_embeddings(
+ height=args.height,
+ width=args.width,
+ num_frames=num_frames,
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
+ patch_size=transformer.config.patch_size,
+ attention_head_dim=transformer.config.attention_head_dim,
+ device=accelerator.device,
+ )
+ if transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
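+ # i.e. noisy_model_input = sqrt(alpha_prod_t) * model_input + sqrt(1 - alpha_prod_t) * noise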
+ noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timesteps,
+ image_rotary_emb=image_rotary_emb,
+ return_dict=False,
+ )[0]
+
+ if scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif scheduler.config.prediction_type == "v_prediction":
+ target = scheduler.get_velocity(model_input, noise, timesteps)
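+ # velocity target: v_t = sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * model_input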
+ else:
+ raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(transformer.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else transformer.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
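+ # e.g. with checkpoints_total_limit=3 and 3 checkpoints already on disk, the oldest one is deleted before saving the new one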
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # Create pipeline
+ pipe = CogVideoXPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=unwrap_model(transformer),
+ text_encoder=unwrap_model(text_encoder),
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
+ for validation_prompt in validation_prompts:
+ pipeline_args = {
+ "prompt": validation_prompt,
+ "guidance_scale": args.guidance_scale,
+ "use_dynamic_cfg": args.use_dynamic_cfg,
+ }
+
+ validation_outputs = log_validation(
+ pipe=pipe,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ )
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ transformer = transformer.to(torch.float32)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ if args.train_text_encoder:
+ text_encoder = unwrap_model(text_encoder)
+ text_encoder_lora_layers = get_peft_model_state_dict(text_encoder.to(torch.float32))
+ else:
+ text_encoder_lora_layers = None
+
+ CogVideoXPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ )
+
+ # Final inference
+ pipe = CogVideoXPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipe.load_lora_weights(args.output_dir)
+
+ # run inference
+ validation_outputs = []
+ if args.validation_prompt and args.num_validation_videos > 0:
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
+ for validation_prompt in validation_prompts:
+ pipeline_args = {
+ "prompt": validation_prompt,
+ "guidance_scale": args.guidance_scale,
+ "use_dynamic_cfg": args.use_dynamic_cfg,
+ }
+
+ video = log_validation(
+ pipe=pipe,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+ validation_outputs.extend(video)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ videos=validation_outputs,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ fps=args.fps,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ print("Hello, world!")
+ args = get_args()
+ main(args)
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index bccd37ddc4..bf72122168 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -67,6 +67,7 @@ if is_torch_available():
"StableDiffusionXLLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
+ "CogVideoXLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -84,6 +85,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import IPAdapterMixin
from .lora_pipeline import (
AmusedLoraLoaderMixin,
+ CogVideoXLoraLoaderMixin,
FluxLoraLoaderMixin,
LoraLoaderMixin,
SD3LoraLoaderMixin,
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index cefe66bc8c..fe1fe6d00d 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2257,6 +2257,468 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
)
+class CogVideoXLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`CogVideoXTransformer3DModel`],
+ [`T5EncoderModel`](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel). Specific to
+ [`CogVideoX`].
+ """
+
+ _lora_loadable_modules = ["transformer", "text_encoder"]
+ transformer_name = TRANSFORMER_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity. This function is experimental and
+ might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = cls._fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ return state_dict
+
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
+ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+                them from the text encoder lora layers.
+            transformer (`CogVideoXTransformer3DModel`):
+                The transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
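+
+        Example (an illustrative sketch mirroring what [`~loaders.CogVideoXLoraLoaderMixin.load_lora_weights`] does
+        internally; `my-user/cogvideox-lora` is a hypothetical adapter repository):
+
+        ```py
+        from diffusers import CogVideoXPipeline
+
+        pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b")
+        # Fetch the serialized LoRA state dict and inject it into the transformer.
+        state_dict = CogVideoXPipeline.lora_state_dict("my-user/cogvideox-lora")
+        CogVideoXPipeline.load_lora_into_transformer(state_dict, transformer=pipe.transformer)
+        ```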
+ """
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+ keys = list(state_dict.keys())
+
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+ state_dict = {
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+ }
+
+ if len(state_dict.keys()) > 0:
+            # use the first key to check whether the state dict is already in the PEFT format
+ first_key = next(iter(state_dict.keys()))
+ if "lora_A" not in first_key:
+ state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+ if adapter_name in getattr(transformer, "peft_config", {}):
+ raise ValueError(
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+ )
+
+ rank = {}
+ for key, val in state_dict.items():
+ if "lora_B" in key:
+ rank[key] = val.shape[1]
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(transformer)
+
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+ # otherwise loading LoRA weights will lead to an error
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys should be prefixed with an
+                additional `text_encoder` to distinguish them from the transformer lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`T5EncoderModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added to the output of the regular
+                layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
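+
+        Example (an illustrative sketch mirroring how [`~loaders.CogVideoXLoraLoaderMixin.load_lora_weights`] calls
+        this method; `my-user/cogvideox-lora` is a hypothetical adapter repository):
+
+        ```py
+        from diffusers import CogVideoXPipeline
+
+        pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b")
+        state_dict = CogVideoXPipeline.lora_state_dict("my-user/cogvideox-lora")
+        # Keep only the text encoder entries before handing them to this method.
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        CogVideoXPipeline.load_lora_into_text_encoder(
+            text_encoder_state_dict, network_alphas=None, text_encoder=pipe.text_encoder, prefix="text_encoder"
+        )
+        ```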
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ from peft import LoraConfig
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ prefix = cls.text_encoder_name if prefix is None else prefix
+
+ # Safe prefix to check with.
+ if any(cls.text_encoder_name in key for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+ # convert state dict
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ for module in ("fc1", "fc2"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+ ]
+ network_alphas = {
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=text_encoder_lora_state_dict,
+ peft_config=lora_config,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+        Save the LoRA parameters corresponding to the transformer and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+                encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes; in that case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
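+
+        Example (an illustrative sketch; the LoRA rank, alpha and target modules below are arbitrary example values):
+
+        ```py
+        from peft import LoraConfig, get_peft_model_state_dict
+
+        from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
+
+        transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="transformer")
+        # Inject a (freshly initialized) LoRA adapter; in practice this would come from a training run.
+        transformer.add_adapter(LoraConfig(r=64, lora_alpha=64, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))
+
+        CogVideoXPipeline.save_lora_weights(
+            save_directory="cogvideox-lora",
+            transformer_lora_layers=get_peft_model_state_dict(transformer),
+        )
+        ```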
+ """
+ state_dict = {}
+
+ if not (transformer_lora_layers or text_encoder_lora_layers):
+ raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+ Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+ Example:
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
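+
+        Example (an illustrative sketch continuing the `fuse_lora` example above):
+
+        ```py
+        pipeline.fuse_lora(lora_scale=0.7)
+        # ... run inference with the fused LoRA weights ...
+        pipeline.unfuse_lora()
+        ```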
+ """
+ super().unfuse_lora(components=components)
+
+
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 89d6a28b14..d1c6721512 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -33,6 +33,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"FluxTransformer2DModel": lambda model_cls, weights: weights,
+ "CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
}
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index c8d4b18963..753514a42e 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -19,6 +19,7 @@ import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
@@ -152,7 +153,7 @@ class CogVideoXBlock(nn.Module):
return hidden_states, encoder_hidden_states
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index 11f491e495..e48dda93f7 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -22,11 +22,19 @@ import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import CogVideoXLoraLoaderMixin
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from ...utils import BaseOutput, logging, replace_example_docstring
+from ...utils import (
+ USE_PEFT_BACKEND,
+ BaseOutput,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
@@ -151,7 +159,7 @@ class CogVideoXPipelineOutput(BaseOutput):
frames: torch.Tensor
-class CogVideoXPipeline(DiffusionPipeline):
+class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.
@@ -258,6 +266,7 @@ class CogVideoXPipeline(DiffusionPipeline):
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+ lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -284,9 +293,20 @@ class CogVideoXPipeline(DiffusionPipeline):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, CogVideoXLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
@@ -326,6 +346,11 @@ class CogVideoXPipeline(DiffusionPipeline):
dtype=dtype,
)
+ if self.text_encoder is not None:
+ if isinstance(self, CogVideoXLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
@@ -507,6 +532,7 @@ class CogVideoXPipeline(DiffusionPipeline):
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
+ lora_scale: Optional[float] = None,
) -> Union[CogVideoXPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -634,6 +660,7 @@ class CogVideoXPipeline(DiffusionPipeline):
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
+ lora_scale=lora_scale,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)