diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py index 35e3ae6a6d..cf5e6e399a 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -1,3 +1,4 @@ +import copy from dataclasses import dataclass from typing import Callable, List, Optional, Union @@ -56,8 +57,8 @@ class CrossFrameAttnProcessor: is_cross_attention = encoder_hidden_states is not None if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -285,7 +286,8 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline): latents: latents of backward process output at time timesteps[-1] """ do_classifier_free_guidance = guidance_scale > 1.0 - with self.progress_bar(total=len(timesteps)) as progress_bar: + num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order + with self.progress_bar(total=num_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -465,6 +467,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline): extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps, ) + scheduler_copy = copy.deepcopy(self.scheduler) # Perform the second backward process up to time T_0 x_1_t0 = self.backward_loop( @@ -475,7 +478,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline): callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, - num_warmup_steps=num_warmup_steps, + num_warmup_steps=0, ) # Propagate first frame latents at time T_0 to remaining frames @@ -502,7 +505,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline): b, l, d = prompt_embeds.size() prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d) - self.scheduler.set_timesteps(num_inference_steps, device=device) + self.scheduler = scheduler_copy x_1k_0 = self.backward_loop( timesteps=timesteps[-t1 - 1 :], prompt_embeds=prompt_embeds, @@ -511,7 +514,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline): callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, - num_warmup_steps=num_warmup_steps, + num_warmup_steps=0, ) latents = x_1k_0 diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index bb159d9db3..c717d722f8 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -86,6 +86,7 @@ if is_torch_available(): load_hf_numpy, load_image, load_numpy, + load_pt, nightly, parse_flag_from_env, print_tensor_test, diff --git a/tests/pipelines/text_to_video/test_text_to_video_zero.py b/tests/pipelines/text_to_video/test_text_to_video_zero.py index e6a726bf13..45bb93fbd9 100644 --- a/tests/pipelines/text_to_video/test_text_to_video_zero.py +++ b/tests/pipelines/text_to_video/test_text_to_video_zero.py @@ -18,7 +18,7 @@ import unittest import torch from diffusers import DDIMScheduler, TextToVideoZeroPipeline -from diffusers.utils import require_torch_gpu, slow +from diffusers.utils import load_pt, require_torch_gpu, slow from ...test_pipelines_common import assert_mean_pixel_difference @@ -35,8 +35,8 @@ class TextToVideoZeroPipelineSlowTests(unittest.TestCase): prompt = "A bear is playing a guitar on Times Square" result = pipe(prompt=prompt, generator=generator).images - expected_result = torch.load( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/tree/main/text-to-video/A bear is playing a guitar on Times Square.pt" + expected_result = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt" ) assert_mean_pixel_difference(result, expected_result)