Mirror of https://github.com/huggingface/diffusers.git
[training] CogVideoX-I2V LoRA (#9482)
* update
* update
* update
* update
* update
* add coauthor (Co-Authored-By: yuan-shenghai <963658029@qq.com>)
* add coauthor (Co-Authored-By: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com>)
* update (Co-Authored-By: yuan-shenghai <963658029@qq.com>)
* update

Co-authored-by: yuan-shenghai <963658029@qq.com>
Co-authored-by: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com>
@@ -10,6 +10,11 @@ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-de
At the moment, LoRA finetuning has only been tested for [CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b).

> [!NOTE]
> The scripts for CogVideoX come with limited support and may not be fully compatible with different training techniques. They are not feature-rich either and simply serve as minimal examples of finetuning to take inspiration from and improve.
>
> A repository containing memory-optimized finetuning scripts with support for multiple resolutions, dataset preparation, captioning, etc. is available [here](https://github.com/a-r-r-o-w/cogvideox-factory), which will be maintained jointly by the CogVideoX and Diffusers teams.

## Data Preparation

The training scripts accept data in two formats.
@@ -132,6 +137,8 @@ Assuming you are training on 50 videos of a similar concept, we have found 1500-
- 1500 steps on 50 videos would correspond to `30` training epochs
- 4000 steps on 100 videos would correspond to `40` training epochs
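As a quick sanity check of the two bullet points above, here is a minimal sketch of the step-to-epoch conversion (it assumes a train batch size of 1 and no gradient accumulation, which are not stated explicitly here):

```python
# Rough steps <-> epochs conversion for the examples above.
# Assumes train_batch_size=1 and gradient_accumulation_steps=1 (hypothetical defaults).
def steps_to_epochs(max_train_steps: int, num_videos: int, batch_size: int = 1, grad_accum: int = 1) -> float:
    steps_per_epoch = num_videos / (batch_size * grad_accum)
    return max_train_steps / steps_per_epoch

print(steps_to_epochs(1500, 50))   # 30.0 -> ~30 epochs
print(steps_to_epochs(4000, 100))  # 40.0 -> ~40 epochs
```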

The following bash script launches training for text-to-video lora.

```bash
#!/bin/bash

@@ -172,6 +179,8 @@ accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \
--report_to wandb
```

For image-to-video finetuning, run the `train_cogvideox_image_to_video_lora.py` file instead. Additionally, you will have to pass `--validation_images` with paths to initial images corresponding to `--validation_prompts` for I2V validation to work.

To better track our training experiments, we're using the following flags in the command above:

* `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
@@ -197,8 +206,6 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
>
> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.

<!-- TODO: Test finetuning with CogVideoX-5b and CogVideoX-5b-I2V and update scripts accordingly -->

## Inference

Once you have trained a LoRA model, inference can be done by simply loading the LoRA weights into the `CogVideoXPipeline`.
@@ -227,3 +234,5 @@ prompt = (
frames = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]
export_to_video(frames, "output.mp4", fps=8)
```

If you've trained a LoRA for `CogVideoXImageToVideoPipeline` instead, everything in the above example remains the same except that you must also pass an image as the initial condition for generation, for example:
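A minimal I2V sketch along those lines, mirroring the text-to-video snippet above, might look like the following (the LoRA path, image path, and prompt are placeholders, and `THUDM/CogVideoX-5b-I2V` is assumed as the base model since it is the image-to-video checkpoint):

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")
# Placeholder path/adapter name: point this at the output directory of your I2V LoRA training run.
pipe.load_lora_weights("path/to/lora-output-dir", adapter_name="cogvideox-i2v-lora")

# The image acts as the initial condition (first frame) for generation.
image = load_image("path/to/first_frame.png")
prompt = "A panda, dressed in a small red jacket, plays a guitar by a tranquil stream."

frames = pipe(image=image, prompt=prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]
export_to_video(frames, "output.mp4", fps=8)
```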
examples/cogvideo/train_cogvideox_image_to_video_lora.py: new file, 1621 lines (file diff suppressed because it is too large)
@@ -25,7 +25,7 @@ import numpy as np
import torch
import torchvision.transforms as TT
import transformers
- from accelerate import Accelerator
+ from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
@@ -922,7 +922,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
)
args.optimizer = "adamw"

- if args.use_8bit_adam and not (args.optimizer.lower() not in ["adam", "adamw"]):
+ if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
@@ -1211,7 +1211,7 @@ def main(args):
)
use_deepspeed_scheduler = (
accelerator.state.deepspeed_plugin is not None
- and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+ and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
)

optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
@@ -1255,6 +1255,7 @@ def main(args):
prompts = [example["instance_prompt"] for example in examples]

videos = torch.cat(videos)
+ videos = videos.permute(0, 2, 1, 3, 4)
videos = videos.to(memory_format=torch.contiguous_format).float()

return {
@@ -1376,7 +1377,7 @@ def main(args):
models_to_accumulate = [transformer]

with accelerator.accumulate(models_to_accumulate):
- model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
+ model_input = batch["videos"].to(dtype=weight_dtype) # [B, F, C, H, W]
prompts = batch["prompts"]

# encode prompts
@@ -1455,7 +1456,7 @@ def main(args):
progress_bar.update(1)
global_step += 1

- if accelerator.is_main_process:
+ if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1495,7 +1496,6 @@ def main(args):
args.pretrained_model_name_or_path,
transformer=unwrap_model(transformer),
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
scheduler=scheduler,
revision=args.revision,
variant=args.variant,
@@ -1539,6 +1539,10 @@ def main(args):
transformer_lora_layers=transformer_lora_layers,
)

+ # Cleanup trained models to save memory
+ del transformer
+ free_memory()
+
# Final test inference
pipe = CogVideoXPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -15,7 +15,7 @@

import inspect
import math
- from typing import Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import PIL
import torch
@@ -23,6 +23,7 @@ from transformers import T5EncoderModel, T5Tokenizer

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
+ from ...loaders import CogVideoXLoraLoaderMixin
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -152,7 +153,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")


- class CogVideoXImageToVideoPipeline(DiffusionPipeline):
+ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for image-to-video generation using CogVideoX.

@@ -546,6 +547,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
def num_timesteps(self):
return self._num_timesteps

+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
@property
def interrupt(self):
return self._interrupt
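The hunk above adds a read-only `attention_kwargs` property backed by `self._attention_kwargs`, which `__call__` sets further below. A minimal, self-contained sketch of that pattern (a toy class, not the actual diffusers API):

```python
# Toy illustration of the property pattern added above: kwargs passed at call time are stored on the
# instance and exposed read-only, so helpers can read them while the denoising loop runs.
from typing import Any, Dict, Optional


class TinyPipeline:
    def __init__(self) -> None:
        self._attention_kwargs: Optional[Dict[str, Any]] = None

    @property
    def attention_kwargs(self) -> Optional[Dict[str, Any]]:
        return self._attention_kwargs

    def __call__(self, attention_kwargs: Optional[Dict[str, Any]] = None) -> None:
        self._attention_kwargs = attention_kwargs
        # ... a real pipeline would forward self.attention_kwargs to its transformer here ...


pipe = TinyPipeline()
pipe(attention_kwargs={"scale": 0.9})
print(pipe.attention_kwargs)  # {'scale': 0.9}
```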
@@ -572,6 +577,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -635,6 +641,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -679,6 +689,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
self._interrupt = False

# 2. Default call parameters
@@ -768,6 +779,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()