
Text2video zero refinements (#3733)

* fix docs typos. add frame_ids argument to text2video-zero pipeline call

* make style && make quality

* add support of pytorch 2.0 scaled_dot_product_attention for CrossFrameAttnProcessor

* add chunk-by-chunk processing to text2video-zero docs

* make style && make quality

* Update docs/source/en/api/pipelines/text_to_video_zero.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Andranik Movsisyan
2023-06-12 20:03:18 +04:00
committed by GitHub
parent f46b22ba13
commit a812fb6f5c
2 changed files with 130 additions and 9 deletions


@@ -80,6 +80,41 @@ You can change these parameters in the pipeline call:
* Video length:
* `video_length`, the number of frames to be generated. Default: `video_length=8`
We can also generate longer videos by doing the processing in a chunk-by-chunk manner:
```python
import torch
import imageio
from diffusers import TextToVideoZeroPipeline
import numpy as np

model_id = "runwayml/stable-diffusion-v1-5"
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
seed = 0
video_length = 8
chunk_size = 4
prompt = "A panda is playing guitar on times square"

# Generate the video chunk-by-chunk
result = []
chunk_ids = np.arange(0, video_length, chunk_size - 1)
generator = torch.Generator(device="cuda")
for i in range(len(chunk_ids)):
    print(f"Processing chunk {i + 1} / {len(chunk_ids)}")
    ch_start = chunk_ids[i]
    ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
    # Attach the first frame for Cross Frame Attention
    frame_ids = [0] + list(range(ch_start, ch_end))
    # Fix the seed for the temporal consistency
    generator.manual_seed(seed)
    output = pipe(prompt=prompt, video_length=len(frame_ids), generator=generator, frame_ids=frame_ids)
    result.append(output.images[1:])

# Concatenate chunks and save
result = np.concatenate(result)
result = [(r * 255).astype("uint8") for r in result]
imageio.mimsave("video.mp4", result, fps=4)
```
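To make the chunking concrete, here is a small standalone illustration (plain NumPy, not part of the pipeline API) of how the chunk indices above are formed for `video_length=8` and `chunk_size=4`: each chunk re-generates frame 0 as the cross-frame attention anchor, and `output.images[1:]` drops that duplicate before the chunks are concatenated.
```python
import numpy as np

# Illustration only: reproduce the chunk bookkeeping from the loop above
video_length, chunk_size = 8, 4
chunk_ids = np.arange(0, video_length, chunk_size - 1)  # [0, 3, 6]

for i, ch_start in enumerate(chunk_ids):
    ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
    frame_ids = [0] + list(range(ch_start, ch_end))  # frame 0 is always attached
    print(frame_ids)
# [0, 0, 1, 2]
# [0, 3, 4, 5]
# [0, 6, 7]
```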
### Text-To-Video with Pose Control
To generate a video from a prompt with additional pose control
@@ -202,7 +237,7 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below
reader = imageio.get_reader(video_path, "ffmpeg")
frame_count = 8
video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
canny_edges = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
```
3. Run `StableDiffusionControlNetPipeline` with custom trained DreamBooth model
@@ -223,10 +258,10 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
# fix latents for all frames
latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(canny_edges), 1, 1, 1)
prompt = "oil painting of a beautiful girl avatar style"
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
result = pipe(prompt=[prompt] * len(canny_edges), image=canny_edges, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
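The fixed `latents` above are worth a note: a single noise tensor is repeated across all frames so that every frame starts from identical latent noise, which together with `CrossFrameAttnProcessor` keeps the video temporally consistent. A quick shape check (illustrative only; 8 frames assumed, no GPU needed):
```python
import torch

num_frames = 8  # assumed number of canny-edge frames
latents = torch.randn((1, 4, 64, 64), dtype=torch.float16).repeat(num_frames, 1, 1, 1)

print(latents.shape)                         # torch.Size([8, 4, 64, 64])
print(torch.equal(latents[0], latents[-1]))  # True: all frames share the same starting noise
```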


@@ -38,12 +38,12 @@ def rearrange_4(tensor):
class CrossFrameAttnProcessor:
"""
Cross frame attention processor. For each frame the self-attention is replaced with attention with first frame
Cross frame attention processor. Each frame attends to the first frame.
Args:
batch_size: The number that represents actual batch size, other than the frames.
For example, using calling unet with a single prompt and num_images_per_prompt=1, batch_size should be
equal to 2, due to classifier-free guidance.
For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
2, due to classifier-free guidance.
"""
def __init__(self, batch_size=2):
@@ -63,7 +63,7 @@ class CrossFrameAttnProcessor:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# Sparse Attention
# Cross Frame Attention
if not is_cross_attention:
video_length = key.size()[0] // self.batch_size
first_frame_index = [0] * video_length
@@ -95,6 +95,81 @@ class CrossFrameAttnProcessor:
return hidden_states
class CrossFrameAttnProcessor2_0:
"""
Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0.
Args:
batch_size: The number that represents actual batch size, other than the frames.
For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
2, due to classifier-free guidance.
"""
def __init__(self, batch_size=2):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.batch_size = batch_size
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
inner_dim = hidden_states.shape[-1]
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
is_cross_attention = encoder_hidden_states is not None
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# Cross Frame Attention
if not is_cross_attention:
video_length = key.size()[0] // self.batch_size
first_frame_index = [0] * video_length
# rearrange keys to have batch and frames in the 1st and 2nd dims respectively
key = rearrange_3(key, video_length)
key = key[:, first_frame_index]
# rearrange values to have batch and frames in the 1st and 2nd dims respectively
value = rearrange_3(value, video_length)
value = value[:, first_frame_index]
# rearrange back to original shape
key = rearrange_4(key)
value = rearrange_4(value)
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
@dataclass
class TextToVideoPipelineOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]
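For readers skimming the diff, here is a standalone sketch (illustrative, not taken from the file) of the key/value remapping that both `CrossFrameAttnProcessor` classes perform: within each batch element, every frame's keys and values are swapped for frame 0's, so each frame's self-attention attends to the first frame. The plain `reshape` calls stand in for the file's `rearrange_3`/`rearrange_4` helpers.
```python
import torch

batch_size, video_length, seq_len, dim = 2, 4, 16, 8  # batch_size=2 because of classifier-free guidance
key = torch.randn(batch_size * video_length, seq_len, dim)

# rearrange_3: (batch * frames, seq, dim) -> (batch, frames, seq, dim)
key = key.reshape(batch_size, video_length, seq_len, dim)
# replace every frame's keys with frame 0's keys
key = key[:, [0] * video_length]
# rearrange_4: back to (batch * frames, seq, dim)
key = key.reshape(batch_size * video_length, seq_len, dim)

# within a batch element, all frames now carry identical keys
print(torch.equal(key[0], key[video_length - 1]))  # True
```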
@@ -227,7 +302,12 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
)
self.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
processor = (
CrossFrameAttnProcessor2_0(batch_size=2)
if hasattr(F, "scaled_dot_product_attention")
else CrossFrameAttnProcessor(batch_size=2)
)
self.unet.set_attn_processor(processor)
def forward_loop(self, x_t0, t0, t1, generator):
"""
@@ -338,6 +418,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
callback_steps: Optional[int] = 1,
t0: int = 44,
t1: int = 47,
frame_ids: Optional[List[int]] = None,
):
"""
Function invoked when calling the pipeline for generation.
@@ -399,6 +480,9 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
t1 (`int`, *optional*, defaults to 47):
Timestep t1. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
[paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
frame_ids (`List[int]`, *optional*):
Indexes of the frames that are being generated. This is used when generating longer videos
chunk-by-chunk.
Returns:
[`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]:
@@ -407,7 +491,9 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
assert video_length > 0
frame_ids = list(range(video_length))
if frame_ids is None:
frame_ids = list(range(video_length))
assert len(frame_ids) == video_length
assert num_videos_per_prompt == 1
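Finally, a minimal call sketch of the new `frame_ids` argument (model id, prompt, and device are placeholders, not prescribed by the commit): when `frame_ids` is passed, `video_length` must equal `len(frame_ids)`, otherwise the assertion above fires.
```python
import torch
from diffusers import TextToVideoZeroPipeline

pipe = TextToVideoZeroPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

frame_ids = [0, 5, 6, 7]  # frame 0 acts as the cross-frame attention anchor
output = pipe(
    prompt="A panda is playing guitar on times square",
    video_length=len(frame_ids),  # must match len(frame_ids)
    frame_ids=frame_ids,
)
```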