fix: CogVideox train dataset _preprocess_data crop video (#9574)
* Removed the int8-to-float32 normalization (`* 2.0 - 1.0`) from `train_transforms`, as it caused image overexposure. Added a `_resize_for_rectangle_crop` function to enable video cropping; the cropping mode is configured via `video_reshape_mode`, supporting the options ['center', 'random', 'none'].
* The number 127.5 may experience precision loss during division operations.
* wandb requires images of PIL type
* Fix resizing bug
* Delete Jupyter notebook
* make style
* Update examples/cogvideo/README.md
* make style

Co-authored-by: --unset <--unset>
Co-authored-by: Aryan <aryan@huggingface.co>
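As a side note on the normalization change: the two mappings are algebraically identical, since x/255 * 2 - 1 = (x - 127.5)/127.5, and differ only in float32 rounding. A minimal sketch (not part of the commit) making that concrete:

```python
import torch

# All 256 possible uint8 pixel values, as float32.
x = torch.arange(256, dtype=torch.float32)

old = x / 255.0 * 2.0 - 1.0  # normalization removed by this commit
new = (x - 127.5) / 127.5    # normalization introduced by this commit

# Algebraically identical; numerically they differ only by float32 rounding.
print(torch.max(torch.abs(old - new)).item())  # on the order of 1e-7
```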
@@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
 > [!TIP]
 > You can pass `--use_8bit_adam` to reduce the memory requirements of training.
+> You can pass `--video_reshape_mode` to enable video cropping functionality, supporting the options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.
 
 > [!IMPORTANT]
 > The following settings have been tested at the time of adding CogVideoX LoRA training support:
@@ -21,7 +21,9 @@ import shutil
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
+import torchvision.transforms as TT
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -29,12 +31,14 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
 
 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
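A minimal sketch of the new flag parsed in isolation (a hypothetical standalone script, not the training script): adding `choices` would make argparse reject unsupported modes, whereas the committed code documents them in `help` only:

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical stricter version of the new flag: `choices` enforces the
# three supported modes instead of only documenting them.
parser.add_argument(
    "--video_reshape_mode",
    type=str,
    default="center",
    choices=["center", "random", "none"],
    help="How input videos are reshaped: 'center'/'random' crop, or 'none'.",
)
args = parser.parse_args(["--video_reshape_mode", "random"])
print(args.video_reshape_mode)  # random
```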
@@ -413,6 +423,7 @@ class VideoDataset(Dataset):
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ class VideoDataset(Dataset):
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ class VideoDataset(Dataset):
 
         return instance_prompts, instance_videos
 
+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
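For readers skimming the diff, here is a rough standalone sketch of what `_resize_for_rectangle_crop` does; the function name and shapes are illustrative, not from the commit. The video is resized so that the target rectangle is fully covered, then the excess along the longer dimension is cropped:

```python
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import crop, resize

def resize_then_center_crop(frames: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """frames: [F, C, H, W]. Resize so (height, width) is covered, then center-crop."""
    _, _, h, w = frames.shape
    if w / h > width / height:
        # Source is wider than the target: match heights, crop the extra width.
        frames = resize(frames, [height, int(w * height / h)], interpolation=InterpolationMode.BICUBIC)
    else:
        # Source is taller/narrower: match widths, crop the extra height.
        frames = resize(frames, [int(h * width / w), width], interpolation=InterpolationMode.BICUBIC)
    top = (frames.shape[2] - height) // 2
    left = (frames.shape[3] - width) // 2
    return crop(frames, top=top, left=left, height=height, width=width)

video = torch.rand(8, 3, 360, 640)                      # 8 frames of 360x640
print(resize_then_center_crop(video, 480, 720).shape)   # torch.Size([8, 3, 480, 720])
```

Note that in the committed method above, `none` falls into the same branch as `random`, so it still produces a random crop rather than skipping cropping.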
@@ -542,15 +586,14 @@ class VideoDataset(Dataset):
 
         decord.bridge.set_bridge("torch")
 
-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Loading progress resize and crop videos",
         )
+        videos = []
 
         for filename in self.instance_video_paths:
-            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+            video_reader = decord.VideoReader(uri=filename.as_posix())
             video_num_frames = len(video_reader)
 
             start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@ class VideoDataset(Dataset):
             assert (selected_num_frames - 1) % 4 == 0
 
             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            progress_dataset_bar.set_description(
+                f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
+            )
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+            progress_dataset_bar.update(1)
 
+        progress_dataset_bar.close()
         return videos
 
 
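The unchanged `assert (selected_num_frames - 1) % 4 == 0` context line is worth a note: CogVideoX's VAE compresses time by a factor of 4, so frame counts of the form 4k + 1 (such as the default `--max_num_frames 49`) map cleanly onto latent frames. A quick check, assuming that 4x ratio:

```python
# Frame counts must satisfy (n - 1) % 4 == 0; such an n yields (n - 1) // 4 + 1 latent frames.
for n in (8, 9, 48, 49):
    if (n - 1) % 4 == 0:
        print(f"{n} frames -> {(n - 1) // 4 + 1} latent frames")
    else:
        print(f"{n} frames -> rejected by the assert")
```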
@@ -694,8 +743,13 @@ def log_validation(
 
     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
-        videos.append(video)
+        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
+        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
+
+        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
+        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+
+        videos.append(image_pil)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
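For context on the wandb fix: `VaeImageProcessor.pt_to_numpy` and `VaeImageProcessor.numpy_to_pil` turn the `[F, C, H, W]` tensor returned with `output_type="pt"` (values in [0, 1]) into a list of PIL images, the type wandb's video logging accepts per the commit message. A minimal sketch with a dummy tensor:

```python
import torch
from diffusers.image_processor import VaeImageProcessor

# Dummy "video": 4 frames, 3 channels, 64x64, in [0, 1] as output_type="pt" returns.
pt_frames = torch.rand(4, 3, 64, 64)

np_frames = VaeImageProcessor.pt_to_numpy(pt_frames)    # [4, 64, 64, 3] float numpy array
pil_frames = VaeImageProcessor.numpy_to_pil(np_frames)  # list of 4 PIL.Image objects
print(len(pil_frames), pil_frames[0].size)              # 4 (64, 64)
```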
@@ -1171,6 +1225,7 @@ def main(args):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@ def main(args):
         id_token=args.id_token,
     )
 
-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist
 
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Loading Encode videos",
+    )
+    train_dataset.instance_videos = [
+        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
+    ]
+    progress_encode_bar.close()
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
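The dataset stores latent *distributions* (`latent_dist`) rather than sampled latents, so `collate_fn` can draw a fresh `.sample()` each step. A sketch of that idea; the `DiagonalGaussianDistribution` import path and the toy shapes and scaling factor are assumptions, not from the commit:

```python
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution  # path assumed for recent diffusers

# Fake VAE encoder output: mean and logvar concatenated along the channel dim
# for a [B, 2*C, F, H, W] latent video (toy sizes: B=1, C=16, F=13, H=60, W=90).
params = torch.randn(1, 32, 13, 60, 90)
latent_dist = DiagonalGaussianDistribution(params)

scaling_factor = 0.7  # stand-in for vae.config.scaling_factor
sample_a = latent_dist.sample() * scaling_factor
sample_b = latent_dist.sample() * scaling_factor

# Two draws differ: keeping the distribution preserves this per-step noise.
print(sample_a.shape, torch.allclose(sample_a, sample_b))  # torch.Size([1, 16, 13, 60, 90]) False
```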