mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Text2video zero refinements (#3070)
* fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward * fix tensor loading in test_text_to_video_zero.py * make style && make quality
This commit is contained in:
committed by
Daniel Gu
parent
115e382d3b
commit
10c54cbf8f
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
@@ -56,8 +57,8 @@ class CrossFrameAttnProcessor:
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.cross_attention_norm:
|
||||
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
@@ -285,7 +286,8 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
|
||||
latents: latents of backward process output at time timesteps[-1]
|
||||
"""
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
with self.progress_bar(total=len(timesteps)) as progress_bar:
|
||||
num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
|
||||
with self.progress_bar(total=num_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
@@ -465,6 +467,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
|
||||
extra_step_kwargs=extra_step_kwargs,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
)
|
||||
scheduler_copy = copy.deepcopy(self.scheduler)
|
||||
|
||||
# Perform the second backward process up to time T_0
|
||||
x_1_t0 = self.backward_loop(
|
||||
@@ -475,7 +478,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
extra_step_kwargs=extra_step_kwargs,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_warmup_steps=0,
|
||||
)
|
||||
|
||||
# Propagate first frame latents at time T_0 to remaining frames
|
||||
@@ -502,7 +505,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
|
||||
b, l, d = prompt_embeds.size()
|
||||
prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
self.scheduler = scheduler_copy
|
||||
x_1k_0 = self.backward_loop(
|
||||
timesteps=timesteps[-t1 - 1 :],
|
||||
prompt_embeds=prompt_embeds,
|
||||
@@ -511,7 +514,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
extra_step_kwargs=extra_step_kwargs,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_warmup_steps=0,
|
||||
)
|
||||
latents = x_1k_0
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ if is_torch_available():
|
||||
load_hf_numpy,
|
||||
load_image,
|
||||
load_numpy,
|
||||
load_pt,
|
||||
nightly,
|
||||
parse_flag_from_env,
|
||||
print_tensor_test,
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, TextToVideoZeroPipeline
|
||||
from diffusers.utils import require_torch_gpu, slow
|
||||
from diffusers.utils import load_pt, require_torch_gpu, slow
|
||||
|
||||
from ...test_pipelines_common import assert_mean_pixel_difference
|
||||
|
||||
@@ -35,8 +35,8 @@ class TextToVideoZeroPipelineSlowTests(unittest.TestCase):
|
||||
prompt = "A bear is playing a guitar on Times Square"
|
||||
result = pipe(prompt=prompt, generator=generator).images
|
||||
|
||||
expected_result = torch.load(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/tree/main/text-to-video/A bear is playing a guitar on Times Square.pt"
|
||||
expected_result = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt"
|
||||
)
|
||||
|
||||
assert_mean_pixel_difference(result, expected_result)
|
||||
|
||||
Reference in New Issue
Block a user