From f12e669ed39a54318a473969cdcad56f37d49f6f Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 28 Aug 2024 15:51:57 +0200 Subject: [PATCH] update --- examples/cogvideo/train_cogvideox_lora.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 9713809182..427b6c6f76 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -517,6 +517,8 @@ class VideoDataset(Dataset): def _preprocess_data(self): import decord + decord.bridge.set_bridge("torch") + videos = [] train_transforms = transforms.Compose( @@ -533,12 +535,12 @@ class VideoDataset(Dataset): start_frame = min(self.skip_frames_start, video_num_frames) end_frame = max(0, video_num_frames - self.skip_frames_end) if end_frame <= start_frame: - frames_numpy = video_reader.get_batch([start_frame]).asnumpy() + frames_numpy = video_reader.get_batch([start_frame]).numpy() elif end_frame - start_frame <= self.max_num_frames: - frames_numpy = video_reader.get_batch(list(range(start_frame, end_frame))).asnumpy() + frames_numpy = video_reader.get_batch(list(range(start_frame, end_frame))).numpy() else: indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames)) - frames_numpy = video_reader.get_batch(indices).asnumpy() + frames_numpy = video_reader.get_batch(indices).numpy() # Just to ensure that we don't go over the limit frames_numpy = frames_numpy[: self.max_num_frames] @@ -642,7 +644,7 @@ def log_validation( is_final_validation: bool = False, ): logger.info( - f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args["prompt"]}." + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." ) # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it scheduler_args = {} @@ -803,10 +805,6 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - import decord - - decord.bridge.set_bridge("torch") - logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)