From a812fb6f5c2f147a6d98c994effdaa5c2087e53b Mon Sep 17 00:00:00 2001
From: Andranik Movsisyan <48154088+19and99@users.noreply.github.com>
Date: Mon, 12 Jun 2023 20:03:18 +0400
Subject: [PATCH] Text2video zero refinements (#3733)

* fix docs typos. add frame_ids argument to text2video-zero pipeline call

* make style && make quality

* add support of pytorch 2.0 scaled_dot_product_attention for CrossFrameAttnProcessor

* add chunk-by-chunk processing to text2video-zero docs

* make style && make quality

* Update docs/source/en/api/pipelines/text_to_video_zero.mdx

Co-authored-by: Sayak Paul

---------

Co-authored-by: Sayak Paul
---
 .../en/api/pipelines/text_to_video_zero.mdx | 41 +++++++-
 .../pipeline_text_to_video_zero.py          | 98 +++++++++++++++++--
 2 files changed, 130 insertions(+), 9 deletions(-)

diff --git a/docs/source/en/api/pipelines/text_to_video_zero.mdx b/docs/source/en/api/pipelines/text_to_video_zero.mdx
index 3ee10f01c3..3c3dcf5bb1 100644
--- a/docs/source/en/api/pipelines/text_to_video_zero.mdx
+++ b/docs/source/en/api/pipelines/text_to_video_zero.mdx
@@ -80,6 +80,41 @@ You can change these parameters in the pipeline call:
 * Video length:
     * `video_length`, the number of frames video_length to be generated. Default: `video_length=8`
 
+We can also generate longer videos by doing the processing in a chunk-by-chunk manner:
+```python
+import torch
+import imageio
+from diffusers import TextToVideoZeroPipeline
+import numpy as np
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+seed = 0
+video_length = 8
+chunk_size = 4
+prompt = "A panda is playing guitar on times square"
+
+# Generate the video chunk-by-chunk
+result = []
+chunk_ids = np.arange(0, video_length, chunk_size - 1)
+generator = torch.Generator(device="cuda")
+for i in range(len(chunk_ids)):
+    print(f"Processing chunk {i + 1} / {len(chunk_ids)}")
+    ch_start = chunk_ids[i]
+    ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
+    # Attach the first frame for Cross Frame Attention
+    frame_ids = [0] + list(range(ch_start, ch_end))
+    # Fix the seed for the temporal consistency
+    generator.manual_seed(seed)
+    output = pipe(prompt=prompt, video_length=len(frame_ids), generator=generator, frame_ids=frame_ids)
+    result.append(output.images[1:])
+
+# Concatenate chunks and save
+result = np.concatenate(result)
+result = [(r * 255).astype("uint8") for r in result]
+imageio.mimsave("video.mp4", result, fps=4)
+```
+
 ### Text-To-Video with Pose Control
 To generate a video from prompt with additional pose control
 
@@ -202,7 +237,7 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below
     reader = imageio.get_reader(video_path, "ffmpeg")
     frame_count = 8
-    video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
+    canny_edges = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
     ```
 
 3. Run `StableDiffusionControlNetPipeline` with custom trained DreamBooth model
@@ -223,10 +258,10 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below
     pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
 
     # fix latents for all frames
-    latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
+    latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(canny_edges), 1, 1, 1)
 
     prompt = "oil painting of a beautiful girl avatar style"
-    result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
+    result = pipe(prompt=[prompt] * len(canny_edges), image=canny_edges, latents=latents).images
     imageio.mimsave("video.mp4", result, fps=4)
     ```
 
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index 5b163bbbc8..fe7207f904 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -38,12 +38,12 @@ def rearrange_4(tensor):
 
 class CrossFrameAttnProcessor:
     """
-    Cross frame attention processor. For each frame the self-attention is replaced with attention with first frame
+    Cross frame attention processor. Each frame attends the first frame.
 
     Args:
         batch_size: The number that represents actual batch size, other than the frames.
-            For example, using calling unet with a single prompt and num_images_per_prompt=1, batch_size should be
-            equal to 2, due to classifier-free guidance.
+            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+            2, due to classifier-free guidance.
     """
 
     def __init__(self, batch_size=2):
@@ -63,7 +63,7 @@ class CrossFrameAttnProcessor:
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
-        # Sparse Attention
+        # Cross Frame Attention
         if not is_cross_attention:
             video_length = key.size()[0] // self.batch_size
             first_frame_index = [0] * video_length
@@ -95,6 +95,81 @@ class CrossFrameAttnProcessor:
         return hidden_states
 
 
+class CrossFrameAttnProcessor2_0:
+    """
+    Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0.
+
+    Args:
+        batch_size: The number that represents actual batch size, other than the frames.
+            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+            2, due to classifier-free guidance.
+    """
+
+    def __init__(self, batch_size=2):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        self.batch_size = batch_size
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        is_cross_attention = encoder_hidden_states is not None
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        # Cross Frame Attention
+        if not is_cross_attention:
+            video_length = key.size()[0] // self.batch_size
+            first_frame_index = [0] * video_length
+
+            # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
+            key = rearrange_3(key, video_length)
+            key = key[:, first_frame_index]
+            # rearrange values to have batch and frames in the 1st and 2nd dims respectively
+            value = rearrange_3(value, video_length)
+            value = value[:, first_frame_index]
+
+            # rearrange back to original shape
+            key = rearrange_4(key)
+            value = rearrange_4(value)
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 @dataclass
 class TextToVideoPipelineOutput(BaseOutput):
     images: Union[List[PIL.Image.Image], np.ndarray]
@@ -227,7 +302,12 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
         super().__init__(
             vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
         )
-        self.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+        processor = (
+            CrossFrameAttnProcessor2_0(batch_size=2)
+            if hasattr(F, "scaled_dot_product_attention")
+            else CrossFrameAttnProcessor(batch_size=2)
+        )
+        self.unet.set_attn_processor(processor)
 
     def forward_loop(self, x_t0, t0, t1, generator):
         """
@@ -338,6 +418,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
         callback_steps: Optional[int] = 1,
         t0: int = 44,
         t1: int = 47,
+        frame_ids: Optional[List[int]] = None,
     ):
         """
         Function invoked when calling the pipeline for generation.
@@ -399,6 +480,9 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
             t1 (`int`, *optional*, defaults to 47):
                 Timestep t1. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
                 [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+            frame_ids (`List[int]`, *optional*):
+                Indexes of the frames that are being generated. This is used when generating longer videos
+                chunk-by-chunk.
 
         Returns:
             [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]:
@@ -407,7 +491,9 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
                 likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
         assert video_length > 0
-        frame_ids = list(range(video_length))
+        if frame_ids is None:
+            frame_ids = list(range(video_length))
+        assert len(frame_ids) == video_length
         assert num_videos_per_prompt == 1
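
As a quick way to exercise both additions in this patch (the automatic `CrossFrameAttnProcessor2_0` selection and the new `frame_ids` argument), a minimal sketch along the following lines could be run. It assumes the same `runwayml/stable-diffusion-v1-5` checkpoint and CUDA device used in the docs example above; the frame indices are arbitrary and only illustrate the new argument:

```python
import torch
from diffusers import TextToVideoZeroPipeline

model_id = "runwayml/stable-diffusion-v1-5"  # same checkpoint as the docs example
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# On PyTorch >= 2.0 the constructor now installs CrossFrameAttnProcessor2_0;
# on older PyTorch it falls back to CrossFrameAttnProcessor.
print({type(p).__name__ for p in pipe.unet.attn_processors.values()})

# frame_ids selects which frame indices to generate; the pipeline asserts that
# len(frame_ids) == video_length. The indices below are arbitrary.
frame_ids = [0, 2, 4, 6]
result = pipe(
    prompt="A panda is playing guitar on times square",
    video_length=len(frame_ids),
    frame_ids=frame_ids,
).images
print(len(result))  # 4 generated frames
```

The chunk-by-chunk example added to the docs relies on the same mechanism: each chunk prepends frame 0 to `frame_ids` so that cross frame attention always attends to the anchor frame, and the first returned image of every chunk is dropped before the chunks are concatenated.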