mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -517,6 +517,8 @@ class VideoDataset(Dataset):
|
||||
def _preprocess_data(self):
|
||||
import decord
|
||||
|
||||
decord.bridge.set_bridge("torch")
|
||||
|
||||
videos = []
|
||||
|
||||
train_transforms = transforms.Compose(
|
||||
@@ -533,12 +535,12 @@ class VideoDataset(Dataset):
|
||||
start_frame = min(self.skip_frames_start, video_num_frames)
|
||||
end_frame = max(0, video_num_frames - self.skip_frames_end)
|
||||
if end_frame <= start_frame:
|
||||
frames_numpy = video_reader.get_batch([start_frame]).asnumpy()
|
||||
frames_numpy = video_reader.get_batch([start_frame]).numpy()
|
||||
elif end_frame - start_frame <= self.max_num_frames:
|
||||
frames_numpy = video_reader.get_batch(list(range(start_frame, end_frame))).asnumpy()
|
||||
frames_numpy = video_reader.get_batch(list(range(start_frame, end_frame))).numpy()
|
||||
else:
|
||||
indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
|
||||
frames_numpy = video_reader.get_batch(indices).asnumpy()
|
||||
frames_numpy = video_reader.get_batch(indices).numpy()
|
||||
|
||||
# Just to ensure that we don't go over the limit
|
||||
frames_numpy = frames_numpy[: self.max_num_frames]
|
||||
@@ -642,7 +644,7 @@ def log_validation(
|
||||
is_final_validation: bool = False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args["prompt"]}."
|
||||
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
||||
)
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
@@ -803,10 +805,6 @@ def main(args):
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
import decord
|
||||
|
||||
decord.bridge.set_bridge("torch")
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
Reference in New Issue
Block a user