Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27.
Text2video zero refinements (#3733)
* fix docs typos. add frame_ids argument to text2video-zero pipeline call
* make style && make quality
* add support of pytorch 2.0 scaled_dot_product_attention for CrossFrameAttnProcessor
* add chunk-by-chunk processing to text2video-zero docs
* make style && make quality
* Update docs/source/en/api/pipelines/text_to_video_zero.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Commit a812fb6f5c (parent f46b22ba13), committed via GitHub.
@@ -80,6 +80,41 @@ You can change these parameters in the pipeline call:

* Video length:
    * `video_length`, the number of frames to be generated. Default: `video_length=8`

We can also generate longer videos by doing the processing in a chunk-by-chunk manner:

```python
import torch
import imageio
from diffusers import TextToVideoZeroPipeline
import numpy as np

model_id = "runwayml/stable-diffusion-v1-5"
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
seed = 0
video_length = 8
chunk_size = 4
prompt = "A panda is playing guitar on times square"

# Generate the video chunk-by-chunk
result = []
chunk_ids = np.arange(0, video_length, chunk_size - 1)
generator = torch.Generator(device="cuda")
for i in range(len(chunk_ids)):
    print(f"Processing chunk {i + 1} / {len(chunk_ids)}")
    ch_start = chunk_ids[i]
    ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
    # Attach the first frame for Cross Frame Attention
    frame_ids = [0] + list(range(ch_start, ch_end))
    # Fix the seed for the temporal consistency
    generator.manual_seed(seed)
    output = pipe(prompt=prompt, video_length=len(frame_ids), generator=generator, frame_ids=frame_ids)
    result.append(output.images[1:])

# Concatenate chunks and save
result = np.concatenate(result)
result = [(r * 255).astype("uint8") for r in result]
imageio.mimsave("video.mp4", result, fps=4)
```
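Each chunk re-uses frame 0 as the anchor for cross frame attention and re-seeds the generator, so consecutive chunks stay temporally consistent; the stride of `chunk_size - 1` and the `output.images[1:]` slice account for that extra anchor frame, which is re-generated in every chunk and then dropped.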

### Text-To-Video with Pose Control

To generate a video from a prompt with additional pose control:
````diff
@@ -202,7 +237,7 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below
 
 reader = imageio.get_reader(video_path, "ffmpeg")
 frame_count = 8
-video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
+canny_edges = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
 ```
 
 3. Run `StableDiffusionControlNetPipeline` with custom trained DreamBooth model
````
````diff
@@ -223,10 +258,10 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below
 pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
 
 # fix latents for all frames
-latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
+latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(canny_edges), 1, 1, 1)
 
 prompt = "oil painting of a beautiful girl avatar style"
-result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
+result = pipe(prompt=[prompt] * len(canny_edges), image=canny_edges, latents=latents).images
 imageio.mimsave("video.mp4", result, fps=4)
 ```
````
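The snippet above assumes the frames read from `video_path` are already canny edge maps. If the source video held raw frames instead, the edges could be computed per frame; a minimal sketch using OpenCV (the `to_canny` helper and its thresholds are illustrative assumptions, not part of this commit):

```python
import cv2  # assumes opencv-python is installed
import numpy as np
from PIL import Image


def to_canny(frame, low=100, high=200):  # hypothetical helper; thresholds are arbitrary
    edges = cv2.Canny(frame, low, high)
    # stack the single-channel edge map to 3 channels, the image-like input ControlNet expects
    return Image.fromarray(np.stack([edges] * 3, axis=-1))


# canny_edges = [to_canny(reader.get_data(i)) for i in range(frame_count)]
```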
```diff
@@ -38,12 +38,12 @@ def rearrange_4(tensor):
 
 class CrossFrameAttnProcessor:
     """
-    Cross frame attention processor. For each frame the self-attention is replaced with attention with first frame
+    Cross frame attention processor. Each frame attends the first frame.
 
     Args:
         batch_size: The number that represents actual batch size, other than the frames.
-            For example, using calling unet with a single prompt and num_images_per_prompt=1, batch_size should be
-            equal to 2, due to classifier-free guidance.
+            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+            2, due to classifier-free guidance.
     """
 
     def __init__(self, batch_size=2):
```
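In other words, each frame keeps its own queries in self-attention but borrows the keys and values of the first frame. A minimal standalone sketch of that operation (plain PyTorch with assumed tensor shapes; not the helpers used in this file):

```python
import torch


def cross_frame_kv(key, value, batch_size):
    # key/value: (batch_size * video_length, seq_len, dim), frames stacked per batch entry
    video_length = key.shape[0] // batch_size
    k = key.reshape(batch_size, video_length, *key.shape[1:])
    v = value.reshape(batch_size, video_length, *value.shape[1:])
    # pick frame 0 once per frame: every frame now attends to the first frame's keys/values
    k = k[:, [0] * video_length].reshape(key.shape)
    v = v[:, [0] * video_length].reshape(value.shape)
    return k, v
```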
```diff
@@ -63,7 +63,7 @@ class CrossFrameAttnProcessor:
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
-        # Sparse Attention
+        # Cross Frame Attention
         if not is_cross_attention:
             video_length = key.size()[0] // self.batch_size
             first_frame_index = [0] * video_length
```
```diff
@@ -95,6 +95,81 @@ class CrossFrameAttnProcessor:
         return hidden_states
 
 
+class CrossFrameAttnProcessor2_0:
+    """
+    Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0.
+
+    Args:
+        batch_size: The number that represents actual batch size, other than the frames.
+            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+            2, due to classifier-free guidance.
+    """
+
+    def __init__(self, batch_size=2):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        self.batch_size = batch_size
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        is_cross_attention = encoder_hidden_states is not None
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        # Cross Frame Attention
+        if not is_cross_attention:
+            video_length = key.size()[0] // self.batch_size
+            first_frame_index = [0] * video_length
+
+            # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
+            key = rearrange_3(key, video_length)
+            key = key[:, first_frame_index]
+            # rearrange values to have batch and frames in the 1st and 2nd dims respectively
+            value = rearrange_3(value, video_length)
+            value = value[:, first_frame_index]
+
+            # rearrange back to original shape
+            key = rearrange_4(key)
+            value = rearrange_4(value)
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 @dataclass
 class TextToVideoPipelineOutput(BaseOutput):
     images: Union[List[PIL.Image.Image], np.ndarray]
```
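Both processors lean on the module's `rearrange_3`/`rearrange_4` helpers, which appear in this diff only by name. A rough sketch of the shape contract implied by the calls above (reconstructed from usage, so the actual implementation may differ):

```python
import torch


def rearrange_3(tensor, f):
    # (batch * frames, seq_len, dim) -> (batch, frames, seq_len, dim)
    b, s, d = tensor.shape
    return tensor.reshape(b // f, f, s, d)


def rearrange_4(tensor):
    # (batch, frames, seq_len, dim) -> (batch * frames, seq_len, dim)
    b, f, s, d = tensor.shape
    return tensor.reshape(b * f, s, d)
```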
```diff
@@ -227,7 +302,12 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
         super().__init__(
             vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
         )
-        self.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+        processor = (
+            CrossFrameAttnProcessor2_0(batch_size=2)
+            if hasattr(F, "scaled_dot_product_attention")
+            else CrossFrameAttnProcessor(batch_size=2)
+        )
+        self.unet.set_attn_processor(processor)
 
     def forward_loop(self, x_t0, t0, t1, generator):
         """
```
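The `hasattr(F, "scaled_dot_product_attention")` guard (`F` is `torch.nn.functional` elsewhere in the file) selects the fused attention path only on PyTorch 2.0 or newer. A quick way to see which processor your install would get:

```python
import torch
import torch.nn.functional as F

# scaled_dot_product_attention was introduced in PyTorch 2.0
print(torch.__version__, hasattr(F, "scaled_dot_product_attention"))
```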
```diff
@@ -338,6 +418,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
         callback_steps: Optional[int] = 1,
         t0: int = 44,
         t1: int = 47,
+        frame_ids: Optional[List[int]] = None,
     ):
         """
         Function invoked when calling the pipeline for generation.
```
```diff
@@ -399,6 +480,9 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
             t1 (`int`, *optional*, defaults to 47):
                 Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
                 [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+            frame_ids (`List[int]`, *optional*):
+                Indexes of the frames that are being generated. This is used when generating longer videos
+                chunk-by-chunk.
 
         Returns:
             [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]:
```
```diff
@@ -407,7 +491,9 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
             likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
         assert video_length > 0
-        frame_ids = list(range(video_length))
+        if frame_ids is None:
+            frame_ids = list(range(video_length))
+        assert len(frame_ids) == video_length
 
         assert num_videos_per_prompt == 1
```
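Because `frame_ids` defaults to `None` and falls back to `list(range(video_length))`, existing calls that omit the argument behave exactly as before. An explicit list such as `frame_ids=[0, 4, 5, 6]` generates only those frames, with frame 0 included as the cross frame attention anchor, as in the chunk-by-chunk example above.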