
fix: CogVideox train dataset _preprocess_data crop video (#9574)

* Removed the int8-to-float32 conversion (`* 2.0 - 1.0`) from `train_transforms`, as it caused image overexposure.

Added `_resize_for_rectangle_crop` function to enable video cropping functionality. The cropping mode can be configured via `video_reshape_mode`, supporting options: ['center', 'random', 'none'].

* Dividing by 127.5 may introduce floating-point precision loss (see the normalization sketch after this list).

* wandb requires PIL images for logging

* Fixed a resizing bug

* Removed the Jupyter notebook

* make style

* Update examples/cogvideo/README.md

* make style
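
For reference on the normalization notes above: the removed Lambda and the new `(x - 127.5) / 127.5` step map uint8 pixels into the same [-1, 1] range in exact arithmetic. A minimal standalone sketch (tensor shape is illustrative):

```python
import torch

# Illustrative uint8 frames; the shape is arbitrary for the demo.
frames = torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8).float()

old_norm = frames / 255.0 * 2.0 - 1.0   # removed Lambda transform
new_norm = (frames - 127.5) / 127.5     # normalization now applied in _preprocess_data

print(old_norm.min().item(), old_norm.max().item())   # both map 0..255 into [-1, 1]
print(torch.allclose(old_norm, new_norm, atol=1e-6))  # True up to float rounding
```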

---------

Co-authored-by: --unset <--unset>
Co-authored-by: Aryan <aryan@huggingface.co>
Authored by: glide-the
Committed by: GitHub
Date: 2024-10-08 15:22:52 +08:00
Parent: 63a5c8742a
Commit: 66eef9a6dc
2 changed files with 78 additions and 14 deletions

examples/cogvideo/README.md

@@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
> You can pass `--video_reshape_mode` to configure video cropping, supporting options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.
> [!IMPORTANT]
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
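
For reference, the three `--video_reshape_mode` options differ only in how the crop offset into the resized frame is chosen. A simplified sketch of that selection, mirroring the `_resize_for_rectangle_crop` helper added in the training script below (the `crop_offsets` function and the sizes are illustrative):

```python
import numpy as np

def crop_offsets(frame_h: int, frame_w: int, target_h: int, target_w: int, mode: str = "center"):
    # The frame is assumed to be already resized so that one side matches the target.
    delta_h, delta_w = frame_h - target_h, frame_w - target_w
    if mode in ("random", "none"):
        # Note: "none" currently falls through to the same random offset as "random".
        return np.random.randint(0, delta_h + 1), np.random.randint(0, delta_w + 1)
    if mode == "center":
        return delta_h // 2, delta_w // 2
    raise NotImplementedError(mode)

print(crop_offsets(480, 853, 480, 720, mode="center"))  # (0, 66)
```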

CogVideoX LoRA training script

@@ -21,7 +21,9 @@ import shutil
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torchvision.transforms as TT
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -29,12 +31,14 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
        default=720,
        help="All input videos are resized to this width.",
    )
    parser.add_argument(
        "--video_reshape_mode",
        type=str,
        default="center",
        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
    )
    parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
    parser.add_argument(
        "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -413,6 +423,7 @@ class VideoDataset(Dataset):
        video_column: str = "video",
        height: int = 480,
        width: int = 720,
        video_reshape_mode: str = "center",
        fps: int = 8,
        max_num_frames: int = 49,
        skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ class VideoDataset(Dataset):
        self.video_column = video_column
        self.height = height
        self.width = width
        self.video_reshape_mode = video_reshape_mode
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ class VideoDataset(Dataset):
        return instance_prompts, instance_videos

    def _resize_for_rectangle_crop(self, arr):
        image_size = self.height, self.width
        reshape_mode = self.video_reshape_mode
        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
            arr = resize(
                arr,
                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
                interpolation=InterpolationMode.BICUBIC,
            )
        else:
            arr = resize(
                arr,
                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
                interpolation=InterpolationMode.BICUBIC,
            )

        h, w = arr.shape[2], arr.shape[3]
        arr = arr.squeeze(0)

        delta_h = h - image_size[0]
        delta_w = w - image_size[1]

        if reshape_mode == "random" or reshape_mode == "none":
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        elif reshape_mode == "center":
            top, left = delta_h // 2, delta_w // 2
        else:
            raise NotImplementedError
        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
        return arr

    def _preprocess_data(self):
        try:
            import decord
@@ -542,15 +586,14 @@ class VideoDataset(Dataset):
        decord.bridge.set_bridge("torch")

        videos = []
        train_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
            ]
        progress_dataset_bar = tqdm(
            range(0, len(self.instance_video_paths)),
            desc="Loading progress resize and crop videos",
        )
        videos = []

        for filename in self.instance_video_paths:
            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
            video_reader = decord.VideoReader(uri=filename.as_posix())
            video_num_frames = len(video_reader)
            start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@ class VideoDataset(Dataset):
            assert (selected_num_frames - 1) % 4 == 0

            # Training transforms
            frames = frames.float()
            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
            frames = (frames - 127.5) / 127.5
            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
            progress_dataset_bar.set_description(
                f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
            )
            frames = self._resize_for_rectangle_crop(frames)
            videos.append(frames.contiguous())  # [F, C, H, W]
            progress_dataset_bar.update(1)

        progress_dataset_bar.close()

        return videos
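
A small worked example of the aspect-ratio branch in `_resize_for_rectangle_crop` above, assuming a 720x1280 (HxW) source and the default 480x720 target (values are illustrative):

```python
# Source frames: H=720, W=1280; target: H=480, W=720 (self.height, self.width).
in_h, in_w = 720, 1280
tgt_h, tgt_w = 480, 720

if in_w / in_h > tgt_w / tgt_h:                  # 1.78 > 1.5: the source is wider than the target
    resized = (tgt_h, int(in_w * tgt_h / in_h))  # match height -> (480, 853), aspect ratio kept
else:
    resized = (int(in_h * tgt_w / in_w), tgt_w)  # otherwise match width

print(resized)  # (480, 853); the crop then removes 853 - 720 = 133 columns (center or random)
```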
@@ -694,8 +743,13 @@ def log_validation(
    videos = []
    for _ in range(args.num_validation_videos):
        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
        videos.append(video)
        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
        videos.append(image_pil)

    for tracker in accelerator.trackers:
        phase_name = "test" if is_final_validation else "validation"
@@ -1171,6 +1225,7 @@ def main(args):
        video_column=args.video_column,
        height=args.height,
        width=args.width,
        video_reshape_mode=args.video_reshape_mode,
        fps=args.fps,
        max_num_frames=args.max_num_frames,
        skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@ def main(args):
        id_token=args.id_token,
    )

    def encode_video(video):
    def encode_video(video, bar):
        bar.update(1)
        video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
        latent_dist = vae.encode(video).latent_dist
        return latent_dist

    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
    progress_encode_bar = tqdm(
        range(0, len(train_dataset.instance_videos)),
        desc="Loading Encode videos",
    )
    train_dataset.instance_videos = [
        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
    ]
    progress_encode_bar.close()

    def collate_fn(examples):
        videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
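
A shape-only sketch of what `encode_video` passes to the VAE after these changes (no model required; the tensor sizes are illustrative):

```python
import torch

frames = torch.rand(49, 3, 480, 720)  # one preprocessed video from the dataset: [F, C, H, W]

video = frames.unsqueeze(0)           # [B, F, C, H, W] = [1, 49, 3, 480, 720]
video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W] = [1, 3, 49, 480, 720]

print(video.shape)  # this channels-first video layout is what vae.encode(video) receives
```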