Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

[MS Text To Video] Add first text to video (#2738)

* [MS Text To Video] Add first text to video * upload * make first model example * match unet3d params * make sure weights are correctly converted * improve * forward pass works, but diff result * make forward work * fix more * finish * refactor video output class. * feat: add support for a video export utility. * fix: opencv availability check. * run make fix-copies. * add: docs for the model components. * add: standalone pipeline doc. * edit docstring of the pipeline. * add: right path to TransformerTempModel * add: first set of tests. * complete fast tests for text to video. * fix bug * up * three fast tests failing. * add: note on slow tests * make work with all schedulers * apply styling. * add slow tests * change file name * update * more correction * more fixes * finish * up * Apply suggestions from code review * up * finish * make copies * fix pipeline tests * fix more tests * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * apply suggestions * up * revert --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

This commit is contained in: commit ca1a22296d (parent 7fe88613fa), committed by GitHub.
@@ -192,6 +192,8 @@
       title: Stable unCLIP
     - local: api/pipelines/stochastic_karras_ve
       title: Stochastic Karras VE
+    - local: api/pipelines/text_to_video
+      title: Text-to-Video
     - local: api/pipelines/unclip
       title: UnCLIP
     - local: api/pipelines/latent_diffusion_uncond
@@ -37,6 +37,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## UNet2DConditionModel
 [[autodoc]] UNet2DConditionModel

+## UNet3DConditionOutput
+[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
+
+## UNet3DConditionModel
+[[autodoc]] UNet3DConditionModel
+
 ## DecoderOutput
 [[autodoc]] models.vae.DecoderOutput

@@ -58,6 +64,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## Transformer2DModelOutput
 [[autodoc]] models.transformer_2d.Transformer2DModelOutput

+## TransformerTemporalModel
+[[autodoc]] models.transformer_temporal.TransformerTemporalModel
+
+## TransformerTemporalModelOutput
+[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
+
 ## PriorTransformer
 [[autodoc]] models.prior_transformer.PriorTransformer
@@ -77,6 +77,7 @@ available a colab notebook to directly try them out.
 | [stable_unclip](./stable_unclip) | **Stable unCLIP** | Text-to-Image Generation |
 | [stable_unclip](./stable_unclip) | **Stable unCLIP** | Image-to-Image Text-Guided Generation |
 | [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
+| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation |
 | [unclip](./unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation |
 | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
 | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |

docs/source/en/api/pipelines/text_to_video.mdx (new file, 122 lines)

@@ -0,0 +1,122 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Text-to-video synthesis

Text-to-video synthesis from [ModelScope](https://modelscope.cn/) follows the same overall structure as Stable Diffusion, but extends it from static images to videos: the system generates a video that matches a natural language text prompt.

From the [model summary](https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis):

*This model is based on a multi-stage text-to-video generation diffusion model, which inputs a description text and returns a video that matches the text description. Only English input is supported.*

Resources:

* [Website](https://modelscope.cn/models/damo/text-to-video-synthesis/summary)
* [GitHub repository](https://github.com/modelscope/modelscope/)
* [Spaces] (TODO)

## Available Pipelines:

| Pipeline | Tasks | Demo |
|---|---|:---:|
| [DiffusionPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [Spaces] (TODO) |
## Usage example

Let's start by generating a short video with the default length of 16 frames (2s at 8 fps):

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")

prompt = "Spiderman is surfing"
video_frames = pipe(prompt).frames
video_path = export_to_video(video_frames)
video_path
```

Diffusers supports different optimization techniques to improve the latency
and memory footprint of a pipeline. Since videos are often more memory-heavy than images,
we can enable CPU offloading and VAE slicing to keep the memory footprint at bay.

Let's generate a video of 8 seconds (64 frames) on the same GPU using CPU offloading and VAE slicing:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
pipe.enable_model_cpu_offload()

# memory optimization
pipe.enable_vae_slicing()

prompt = "Darth Vader surfing a wave"
video_frames = pipe(prompt, num_frames=64).frames
video_path = export_to_video(video_frames)
video_path
```

It takes just **7 GB of GPU memory** to generate the 64 video frames using PyTorch 2.0, fp16 precision and the techniques mentioned above.
We can also use a different scheduler easily, using the same method we'd use for Stable Diffusion:

```python
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video

pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

prompt = "Spiderman is surfing"
video_frames = pipe(prompt, num_inference_steps=25).frames
video_path = export_to_video(video_frames)
video_path
```

Here are some sample outputs:

<table>
    <tr>
        <td><center>
        An astronaut riding a horse.
        <br>
        <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astr.gif"
            alt="An astronaut riding a horse."
            style="width: 300px;" />
        </center></td>
        <td><center>
        Darth Vader surfing in waves.
        <br>
        <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vader.gif"
            alt="Darth Vader surfing in waves."
            style="width: 300px;" />
        </center></td>
    </tr>
</table>

## Available checkpoints

* [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/)
* [damo-vilab/text-to-video-ms-1.7b-legacy](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b-legacy)

## DiffusionPipeline
[[autodoc]] DiffusionPipeline
    - all
    - __call__
@@ -84,8 +84,9 @@ The library has three main components:
 | [stable_unclip](./stable_unclip) | Stable unCLIP | Text-to-Image Generation |
 | [stable_unclip](./stable_unclip) | Stable unCLIP | Image-to-Image Text-Guided Generation |
 | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
+| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation |
 | [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)(implementation by [kakaobrain](https://github.com/kakaobrain/karlo)) | Text-to-Image Generation |
 | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
 | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
 | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
 | [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
@@ -216,7 +216,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
         if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
             from accelerate import cpu_offload_with_hook
         else:
-            raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

         device = torch.device(f"cuda:{gpu_id}")

@@ -314,7 +314,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
         if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
             from accelerate import cpu_offload_with_hook
         else:
-            raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

         device = torch.device(f"cuda:{gpu_id}")

@@ -314,7 +314,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
         if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
             from accelerate import cpu_offload_with_hook
         else:
-            raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

         device = torch.device(f"cuda:{gpu_id}")
scripts/convert_ms_text_to_video_to_diffusers.py (new file, 428 lines)

@@ -0,0 +1,428 @@
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Conversion script for the LDM checkpoints. """
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import UNet3DConditionModel
|
||||
|
||||
|
||||
def assign_to_checkpoint(
|
||||
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
||||
):
|
||||
"""
|
||||
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
||||
attention layers, and takes into account additional replacements that may arise.
|
||||
|
||||
Assigns the weights to the new checkpoint.
|
||||
"""
|
||||
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||
|
||||
# Splits the attention layers into three variables.
|
||||
if attention_paths_to_split is not None:
|
||||
for path, path_map in attention_paths_to_split.items():
|
||||
old_tensor = old_checkpoint[path]
|
||||
channels = old_tensor.shape[0] // 3
|
||||
|
||||
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
||||
|
||||
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
||||
|
||||
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
||||
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||
|
||||
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
||||
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
||||
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
||||
|
||||
for path in paths:
|
||||
new_path = path["new"]
|
||||
|
||||
# These have already been assigned
|
||||
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
||||
continue
|
||||
|
||||
if additional_replacements is not None:
|
||||
for replacement in additional_replacements:
|
||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
weight = old_checkpoint[path["old"]]
|
||||
names = ["proj_attn.weight"]
|
||||
names_2 = ["proj_out.weight", "proj_in.weight"]
|
||||
if any(k in new_path for k in names):
|
||||
checkpoint[new_path] = weight[:, :, 0]
|
||||
elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path:
|
||||
checkpoint[new_path] = weight[:, :, 0]
|
||||
else:
|
||||
checkpoint[new_path] = weight
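
As a rough illustration of the renaming contract (hypothetical keys and tensors, relying on the `assign_to_checkpoint` defined above):

```python
import torch

# Hypothetical source/target state dicts.
old_sd = {"input_blocks.1.0.in_layers.0.weight": torch.ones(4)}
new_sd = {}

# `paths` holds the locally renamed keys; `additional_replacements` then applies
# a global prefix rename ("input_blocks.1.0" -> "down_blocks.0.resnets.0").
paths = [{"old": "input_blocks.1.0.in_layers.0.weight", "new": "input_blocks.1.0.norm1.weight"}]
meta_path = {"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"}

assign_to_checkpoint(paths, new_sd, old_sd, additional_replacements=[meta_path])
assert "down_blocks.0.resnets.0.norm1.weight" in new_sd
```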
|
||||
|
||||
|
||||
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside attentions to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item
|
||||
|
||||
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
||||
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
||||
|
||||
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
||||
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
||||
|
||||
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
"""
|
||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||
"""
|
||||
if n_shave_prefix_segments >= 0:
|
||||
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
||||
else:
|
||||
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
||||
|
||||
|
||||
def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
mapping.append({"old": old_item, "new": old_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item.replace("in_layers.0", "norm1")
|
||||
new_item = new_item.replace("in_layers.2", "conv1")
|
||||
|
||||
new_item = new_item.replace("out_layers.0", "norm2")
|
||||
new_item = new_item.replace("out_layers.3", "conv2")
|
||||
|
||||
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
||||
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
||||
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
if "temopral_conv" not in old_item:
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
|
||||
"""
|
||||
Takes a state dict and a config, and returns a converted checkpoint.
|
||||
"""
|
||||
|
||||
# extract state_dict for UNet
|
||||
unet_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
unet_key = "model.diffusion_model."
|
||||
|
||||
# at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
||||
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
print(
|
||||
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
||||
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
||||
)
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||
else:
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
print(
|
||||
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
||||
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
||||
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
||||
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
||||
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
||||
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
||||
|
||||
if config["class_embed_type"] is None:
|
||||
# No parameters to port
|
||||
...
|
||||
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
||||
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
||||
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
||||
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
||||
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
||||
else:
|
||||
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
||||
|
||||
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
||||
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
||||
|
||||
first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")]
|
||||
paths = renew_attention_paths(first_temp_attention)
|
||||
meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
||||
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
||||
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
||||
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
||||
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the middle blocks only
|
||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
||||
for layer_id in range(num_middle_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the output blocks only
|
||||
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
||||
output_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_output_blocks)
|
||||
}
|
||||
|
||||
for i in range(1, num_input_blocks):
|
||||
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
||||
|
||||
resnets = [
|
||||
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
||||
]
|
||||
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||
temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key]
|
||||
|
||||
if f"input_blocks.{i}.op.weight" in unet_state_dict:
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.op.weight"
|
||||
)
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.op.bias"
|
||||
)
|
||||
|
||||
paths = renew_resnet_paths(resnets)
|
||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
temporal_convs = [key for key in resnets if "temopral_conv" in key]
|
||||
paths = renew_temp_conv_paths(temporal_convs)
|
||||
meta_path = {
|
||||
"old": f"input_blocks.{i}.0.temopral_conv",
|
||||
"new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}",
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if len(temp_attentions):
|
||||
paths = renew_attention_paths(temp_attentions)
|
||||
meta_path = {
|
||||
"old": f"input_blocks.{i}.2",
|
||||
"new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
resnet_0 = middle_blocks[0]
|
||||
temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key]
|
||||
attentions = middle_blocks[1]
|
||||
temp_attentions = middle_blocks[2]
|
||||
resnet_1 = middle_blocks[3]
|
||||
temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key]
|
||||
|
||||
resnet_0_paths = renew_resnet_paths(resnet_0)
|
||||
meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
|
||||
assign_to_checkpoint(
|
||||
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
|
||||
)
|
||||
|
||||
temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0)
|
||||
meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"}
|
||||
assign_to_checkpoint(
|
||||
temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
|
||||
)
|
||||
|
||||
resnet_1_paths = renew_resnet_paths(resnet_1)
|
||||
meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"}
|
||||
assign_to_checkpoint(
|
||||
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
|
||||
)
|
||||
|
||||
temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1)
|
||||
meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"}
|
||||
assign_to_checkpoint(
|
||||
temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
|
||||
)
|
||||
|
||||
attentions_paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
||||
assign_to_checkpoint(
|
||||
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
temp_attentions_paths = renew_attention_paths(temp_attentions)
|
||||
meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"}
|
||||
assign_to_checkpoint(
|
||||
temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
for i in range(num_output_blocks):
|
||||
block_id = i // (config["layers_per_block"] + 1)
|
||||
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
||||
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
||||
output_block_list = {}
|
||||
|
||||
for layer in output_block_layers:
|
||||
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
||||
if layer_id in output_block_list:
|
||||
output_block_list[layer_id].append(layer_name)
|
||||
else:
|
||||
output_block_list[layer_id] = [layer_name]
|
||||
|
||||
if len(output_block_list) > 1:
|
||||
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
||||
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
||||
temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
|
||||
|
||||
resnet_0_paths = renew_resnet_paths(resnets)
|
||||
paths = renew_resnet_paths(resnets)
|
||||
|
||||
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
temporal_convs = [key for key in resnets if "temopral_conv" in key]
|
||||
paths = renew_temp_conv_paths(temporal_convs)
|
||||
meta_path = {
|
||||
"old": f"output_blocks.{i}.0.temopral_conv",
|
||||
"new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}",
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.bias"
|
||||
]
|
||||
|
||||
# Clear attentions as they have been attributed above.
|
||||
if len(attentions) == 2:
|
||||
attentions = []
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {
|
||||
"old": f"output_blocks.{i}.1",
|
||||
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if len(temp_attentions):
|
||||
paths = renew_attention_paths(temp_attentions)
|
||||
meta_path = {
|
||||
"old": f"output_blocks.{i}.2",
|
||||
"new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
else:
|
||||
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||
for path in resnet_0_paths:
|
||||
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
||||
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
||||
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||
|
||||
temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l]
|
||||
for path in temopral_conv_paths:
|
||||
pruned_path = path.split("temopral_conv.")[-1]
|
||||
old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path])
|
||||
new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path])
|
||||
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
unet_checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
||||
unet = UNet3DConditionModel()
|
||||
|
||||
converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config)
|
||||
|
||||
diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys())
|
||||
diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys())
|
||||
|
||||
assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match"
|
||||
|
||||
# load state_dict
|
||||
unet.load_state_dict(converted_ckpt)
|
||||
|
||||
unet.save_pretrained(args.dump_path)
|
||||
|
||||
# -- finish converting the unet --
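
Once converted, the dumped weights can be reloaded like any other diffusers model (a minimal sketch; `./converted_unet` stands in for whatever was passed as `--dump_path`):

```python
from diffusers import UNet3DConditionModel

# Load the UNet that the script above saved with `save_pretrained`.
unet = UNet3DConditionModel.from_pretrained("./converted_unet")
print(f"{sum(p.numel() for p in unet.parameters()) / 1e6:.0f}M parameters")
```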
|
||||
@@ -41,6 +41,7 @@ else:
         UNet1DModel,
         UNet2DConditionModel,
         UNet2DModel,
+        UNet3DConditionModel,
         VQModel,
     )
     from .optimization import (
@@ -130,6 +131,7 @@ else:
         StableDiffusionUpscalePipeline,
         StableUnCLIPImg2ImgPipeline,
         StableUnCLIPPipeline,
+        TextToVideoSDPipeline,
         UnCLIPImageVariationPipeline,
         UnCLIPPipeline,
         VersatileDiffusionDualGuidedPipeline,

@@ -25,6 +25,7 @@ if is_torch_available():
     from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
+    from .unet_3d_condition import UNet3DConditionModel
     from .vq_model import VQModel

 if is_flax_available():

@@ -184,6 +184,10 @@ class BasicTransformerBlock(nn.Module):
         attention_head_dim (`int`): The number of channels in each head.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         num_embeds_ada_norm (:
             obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
@@ -202,6 +206,7 @@ class BasicTransformerBlock(nn.Module):
         num_embeds_ada_norm: Optional[int] = None,
         attention_bias: bool = False,
         only_cross_attention: bool = False,
+        double_self_attention: bool = False,
         upcast_attention: bool = False,
         norm_elementwise_affine: bool = True,
         norm_type: str = "layer_norm",
@@ -233,10 +238,10 @@ class BasicTransformerBlock(nn.Module):
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

         # 2. Cross-Attn
-        if cross_attention_dim is not None:
+        if cross_attention_dim is not None or double_self_attention:
             self.attn2 = Attention(
                 query_dim=dim,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                 heads=num_attention_heads,
                 dim_head=attention_head_dim,
                 dropout=dropout,
@@ -253,7 +258,7 @@ class BasicTransformerBlock(nn.Module):
         else:
             self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

-        if cross_attention_dim is not None:
+        if cross_attention_dim is not None or double_self_attention:
             # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
             # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
             # the second cross attention block.
@@ -207,6 +207,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
|
||||
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
Args:
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is:
|
||||
@@ -253,6 +254,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
|
||||
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
r"""Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is:
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
@@ -764,3 +779,61 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
|
||||
|
||||
class TemporalConvLayer(nn.Module):
|
||||
"""
|
||||
Temporal convolutional layer that can be used for video (sequence of images) input. Code is mostly copied from:
|
||||
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim, out_dim=None, dropout=0.0):
|
||||
super().__init__()
|
||||
out_dim = out_dim or in_dim
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# conv layers
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
||||
)
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
||||
)
|
||||
self.conv4 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
||||
)
|
||||
|
||||
# zero out the last layer params, so the conv block is an identity at initialization
|
||||
nn.init.zeros_(self.conv4[-1].weight)
|
||||
nn.init.zeros_(self.conv4[-1].bias)
|
||||
|
||||
def forward(self, hidden_states, num_frames=1):
|
||||
hidden_states = (
|
||||
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
)
|
||||
|
||||
identity = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.conv3(hidden_states)
|
||||
hidden_states = self.conv4(hidden_states)
|
||||
|
||||
hidden_states = identity + hidden_states
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
|
||||
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
|
||||
)
|
||||
return hidden_states
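
A small shape sketch (hypothetical sizes) of how the layer folds the frame dimension in and out of the batch dimension:

```python
import torch

layer = TemporalConvLayer(in_dim=32)
# 2 videos of 16 frames each, flattened into the batch dimension as (B * F, C, H, W).
frames = torch.randn(2 * 16, 32, 8, 8)
out = layer(frames, num_frames=16)
# conv4 is zero-initialized, so at init the block acts as an identity residual.
assert out.shape == frames.shape
```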
|
||||
|
||||
src/diffusers/models/transformer_temporal.py (new file, 176 lines)

@@ -0,0 +1,176 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .attention import BasicTransformerBlock
|
||||
from .modeling_utils import ModelMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerTemporalModelOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`)
|
||||
Hidden states conditioned on `encoder_hidden_states` input.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Configure if each TransformerBlock should contain two self-attention layers
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
double_self_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
double_self_attention=double_self_attention,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
timestep=None,
|
||||
class_labels=None,
|
||||
num_frames=1,
|
||||
cross_attention_kwargs=None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
|
||||
conditioning.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`:
|
||||
[`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, channel, height, width = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states[None, None, :]
|
||||
.reshape(batch_size, height, width, channel, num_frames)
|
||||
.permute(0, 3, 4, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
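
A minimal usage sketch (hypothetical sizes) showing that the temporal transformer consumes the same frame-flattened `(batch * num_frames, channels, height, width)` layout as the spatial blocks:

```python
import torch

model = TransformerTemporalModel(
    num_attention_heads=2, attention_head_dim=8, in_channels=16, norm_num_groups=16
)
sample = torch.randn(2 * 4, 16, 8, 8)  # 2 samples x 4 frames, flattened into the batch dim
out = model(sample, num_frames=4).sample
assert out.shape == sample.shape
```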
|
||||
src/diffusers/models/unet_3d_blocks.py (new file, 670 lines)

@@ -0,0 +1,670 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
|
||||
def get_down_block(
|
||||
down_block_type,
|
||||
num_layers,
|
||||
in_channels,
|
||||
out_channels,
|
||||
temb_channels,
|
||||
add_downsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
attn_num_head_channels,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
downsample_padding=None,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=True,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
):
|
||||
if down_block_type == "DownBlock3D":
|
||||
return DownBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "CrossAttnDownBlock3D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
||||
return CrossAttnDownBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(
|
||||
up_block_type,
|
||||
num_layers,
|
||||
in_channels,
|
||||
out_channels,
|
||||
prev_output_channel,
|
||||
temb_channels,
|
||||
add_upsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
attn_num_head_channels,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=True,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
):
|
||||
if up_block_type == "UpBlock3D":
|
||||
return UpBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "CrossAttnUpBlock3D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
||||
return CrossAttnUpBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=True,
|
||||
upcast_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
temp_convs = [
|
||||
TemporalConvLayer(
|
||||
in_channels,
|
||||
in_channels,
|
||||
dropout=0.1,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
temp_attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
in_channels // attn_num_head_channels,
|
||||
attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
)
|
||||
temp_attentions.append(
|
||||
TransformerTemporalModel(
|
||||
in_channels // attn_num_head_channels,
|
||||
attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
temp_convs.append(
|
||||
TemporalConvLayer(
|
||||
in_channels,
|
||||
in_channels,
|
||||
dropout=0.1,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.temp_convs = nn.ModuleList(temp_convs)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.temp_attentions = nn.ModuleList(temp_attentions)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
num_frames=1,
|
||||
cross_attention_kwargs=None,
|
||||
):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
|
||||
for attn, temp_attn, resnet, temp_conv in zip(
|
||||
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
|
||||
):
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CrossAttnDownBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
attentions = []
|
||||
temp_attentions = []
|
||||
temp_convs = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
temp_convs.append(
|
||||
TemporalConvLayer(
|
||||
out_channels,
|
||||
out_channels,
|
||||
dropout=0.1,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
out_channels // attn_num_head_channels,
|
||||
attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
)
|
||||
temp_attentions.append(
|
||||
TransformerTemporalModel(
|
||||
out_channels // attn_num_head_channels,
|
||||
attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.temp_convs = nn.ModuleList(temp_convs)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.temp_attentions = nn.ModuleList(temp_attentions)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
num_frames=1,
|
||||
cross_attention_kwargs=None,
|
||||
):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
output_states = ()
|
||||
|
||||
for resnet, temp_conv, attn, temp_attn in zip(
|
||||
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
||||
):
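# per layer: spatial resnet -> temporal convolution -> spatial cross-attention -> temporal self-attention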
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class DownBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
downsample_padding=1,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
temp_convs = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
temp_convs.append(
|
||||
TemporalConvLayer(
|
||||
out_channels,
|
||||
out_channels,
|
||||
dropout=0.1,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.temp_convs = nn.ModuleList(temp_convs)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, num_frames=1):
|
||||
output_states = ()
|
||||
|
||||
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class CrossAttnUpBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
temp_convs = []
|
||||
attentions = []
|
||||
temp_attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
temp_convs.append(
|
||||
TemporalConvLayer(
|
||||
out_channels,
|
||||
out_channels,
|
||||
dropout=0.1,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
out_channels // attn_num_head_channels,
|
||||
attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
)
|
||||
temp_attentions.append(
|
||||
TransformerTemporalModel(
|
||||
out_channels // attn_num_head_channels,
|
||||
attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.temp_convs = nn.ModuleList(temp_convs)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.temp_attentions = nn.ModuleList(temp_attentions)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
num_frames=1,
|
||||
cross_attention_kwargs=None,
|
||||
):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
for resnet, temp_conv, attn, temp_attn in zip(
|
||||
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
||||
):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
temp_convs = []
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
temp_convs.append(
|
||||
TemporalConvLayer(
|
||||
out_channels,
|
||||
out_channels,
|
||||
dropout=0.1,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.temp_convs = nn.ModuleList(temp_convs)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
|
||||
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
src/diffusers/models/unet_3d_condition.py (new file, 492 lines)
@@ -0,0 +1,492 @@
|
||||
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2023 The ModelScope Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .unet_3d_blocks import (
|
||||
CrossAttnDownBlock3D,
|
||||
CrossAttnUpBlock3D,
|
||||
DownBlock3D,
|
||||
UNetMidBlock3DCrossAttn,
|
||||
UpBlock3D,
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet3DConditionOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
|
||||
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
UNet3DConditionModel is a conditional 3D UNet model that takes in a noisy sample, a conditional state, and a timestep
and returns a sample-shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the models (such as downloading or saving, etc.)
|
||||
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample.
|
||||
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
|
||||
The tuple of upsample blocks to use.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
||||
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
||||
If `None`, it will skip the normalization and activation layers in post-processing
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
|
||||
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 64): The dimension of the attention heads.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 4,
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"DownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1024,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_out_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], True, 0)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
self.transformer_in = TransformerTemporalModel(
|
||||
num_attention_heads=8,
|
||||
attention_head_dim=attention_head_dim,
|
||||
in_channels=block_out_channels[0],
|
||||
num_layers=1,
|
||||
)
|
||||
|
||||
# class embedding
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[i],
|
||||
downsample_padding=downsample_padding,
|
||||
dual_cross_attention=False,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock3DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=False,
|
||||
)
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
# add upsample block for all BUT final layer
|
||||
if not is_final_block:
|
||||
add_upsample = True
|
||||
self.num_upsamplers += 1
|
||||
else:
|
||||
add_upsample = False
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=reversed_attention_head_dim[i],
|
||||
dual_cross_attention=False,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
if norm_num_groups is not None:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
else:
|
||||
self.conv_norm_out = None
|
||||
self.conv_act = None
|
||||
|
||||
conv_out_padding = (conv_out_kernel - 1) // 2
|
||||
self.conv_out = nn.Conv2d(
|
||||
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
||||
)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_slicable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_slicable_dims(module)
|
||||
|
||||
num_slicable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_slicable_layers * [1]
|
||||
|
||||
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet3DConditionOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, num_channels, num_frames, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
|
||||
Returns:
|
||||
[`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
[`~models.unet_3d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be at least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
default_overall_up_factor = 2**self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
num_frames = sample.shape[2]
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
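# frames are folded into the batch dimension below, so repeat the time embedding and the text
# conditioning once per frame to keep them aligned with the flattened (batch * num_frames) axis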
|
||||
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
# 2. pre-process
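# (batch, channels, num_frames, height, width) -> (batch * num_frames, channels, height, width)
# so that the spatial (2D) layers treat every frame as an independent batch entry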
|
||||
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
sample = self.transformer_in(sample, num_frames=num_frames).sample
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
num_frames=num_frames,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
if down_block_additional_residuals is not None:
|
||||
new_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, down_block_additional_residual in zip(
|
||||
down_block_res_samples, down_block_additional_residuals
|
||||
):
|
||||
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
||||
new_down_block_res_samples += (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = new_down_block_res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
num_frames=num_frames,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if mid_block_additional_residual is not None:
|
||||
sample = sample + mid_block_additional_residual
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
num_frames=num_frames,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
if self.conv_norm_out:
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
# reshape to (batch, channel, num_frames, height, width)
|
||||
sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet3DConditionOutput(sample=sample)
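For reference, a minimal smoke-test sketch of the new 3D UNet (not part of the diff): the tiny configuration, tensor shapes, and values below are illustrative assumptions, and the snippet assumes a diffusers build that already contains this commit.

```py
import torch
from diffusers import UNet3DConditionModel

# tiny, made-up configuration purely for illustration
unet = UNet3DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

sample = torch.randn(1, 4, 2, 32, 32)           # (batch, channels, num_frames, height, width)
encoder_hidden_states = torch.randn(1, 77, 32)  # (batch, seq_len, cross_attention_dim)

# the UNet folds frames into the batch internally and restores them on the way out
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # expected: torch.Size([1, 4, 2, 32, 32])
```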
|
||||
src/diffusers/pipelines/__init__.py
@@ -65,6 +65,7 @@ else:
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
+ from .text_to_video_synthesis import TextToVideoSDPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
@@ -234,7 +234,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -244,7 +244,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -258,7 +258,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -237,7 +237,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -274,7 +274,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -249,7 +249,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -293,7 +293,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -237,7 +237,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -432,7 +432,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -158,7 +158,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
@@ -394,7 +394,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
src/diffusers/pipelines/text_to_video_synthesis/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextToVideoSDPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for text to video pipelines.
|
||||
|
||||
Args:
|
||||
frames (`List[np.ndarray]` or `torch.FloatTensor`):
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
a `torch` tensor. The NumPy arrays represent the denoised images of the diffusion pipeline. The length of the
list denotes the video length, i.e., the number of frames.
|
||||
"""
|
||||
|
||||
frames: Union[List[np.ndarray], torch.FloatTensor]
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_text_to_video_synth import TextToVideoSDPipeline # noqa: F401
|
||||
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py (new file, 668 lines)
@@ -0,0 +1,668 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, UNet3DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import TextToVideoSDPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import TextToVideoSDPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> pipe = TextToVideoSDPipeline.from_pretrained(
|
||||
... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> prompt = "Spiderman is surfing"
|
||||
>>> video_frames = pipe(prompt).frames
|
||||
>>> video_path = export_to_video(video_frames)
|
||||
>>> video_path
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
||||
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
# reshape to ncfhw
|
||||
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
# unnormalize back to [0,1]
|
||||
video = video.mul_(std).add_(mean)
|
||||
video.clamp_(0, 1)
|
||||
# prepare the final outputs
|
||||
i, c, f, h, w = video.shape
|
||||
images = video.permute(2, 3, 0, 4, 1).reshape(
|
||||
f, h, i * w, c
|
||||
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
||||
images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
|
||||
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
||||
return images
|
||||
|
||||
|
||||
class TextToVideoSDPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Same as Stable Diffusion 2.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet3DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
||||
steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
|
||||
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded
|
||||
to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a
|
||||
submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
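# fold the frame axis into the batch so the 2D VAE decoder can process each frame independently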
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = (
|
||||
image[None, :]
|
||||
.reshape(
|
||||
(
|
||||
batch_size,
|
||||
num_frames,
|
||||
-1,
|
||||
)
|
||||
+ image.shape[2:]
|
||||
)
|
||||
.permute(0, 2, 1, 3, 4)
|
||||
)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
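For intuition on the shape computed above: with 4 latent channels and a VAE scale factor of 8 (both depend on the loaded models, so treat the numbers as an assumption), a single 256x256, 16-frame request produces latents of shape (1, 4, 16, 32, 32):

batch_size, num_channels_latents, num_frames = 1, 4, 16
height, width, vae_scale_factor = 256, 256, 8

shape = (
    batch_size,
    num_channels_latents,
    num_frames,
    height // vae_scale_factor,
    width // vae_scale_factor,
)
print(shape)  # (1, 4, 16, 32, 32)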
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: int = 16,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 9.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated video.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated video.
|
||||
num_frames (`int`, *optional*, defaults to 16):
|
||||
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
|
||||
amounts to 2 seconds of video.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower video quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the video generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`. Latents should be of shape
|
||||
`(batch_size, num_channel, num_frames, height, width)`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generate video. Choose between `torch.FloatTensor` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated frames.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
num_images_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# reshape latents
|
||||
bsz, channel, frames, width, height = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
|
||||
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# reshape latents back
|
||||
latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return TextToVideoSDPipelineOutput(frames=video)
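Taken together, the pipeline can be driven end to end roughly as follows. This is a hedged usage sketch: the checkpoint id, the scheduler swap, and the `export_to_video` helper are the ones exercised elsewhere in this diff, not a guarantee of the final public API.

import torch
from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline
from diffusers.utils import export_to_video

pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(0)
# the default output_type="np" returns frames ready for export_to_video
video_frames = pipe("Spiderman is surfing", generator=generator, num_inference_steps=25).frames
video_path = export_to_video(video_frames)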
|
||||
@@ -92,6 +92,8 @@ if is_torch_available():
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .testing_utils import export_to_video
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -122,6 +122,21 @@ class UNet2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UNet3DConditionModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class VQModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -347,6 +347,21 @@ class StableUnCLIPPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class TextToVideoSDPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class UnCLIPImageVariationPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -169,6 +169,14 @@ if _onnx_available:
|
||||
if _onnx_available:
|
||||
logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
|
||||
|
||||
# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
|
||||
# _opencv_available = importlib.util.find_spec("opencv-python") is not None
|
||||
try:
|
||||
_opencv_version = importlib_metadata.version("opencv-python")
|
||||
_opencv_available = True
|
||||
logger.debug(f"Successfully imported cv2 version {_opencv_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_opencv_available = False
|
||||
|
||||
_scipy_available = importlib.util.find_spec("scipy") is not None
|
||||
try:
|
||||
@@ -272,6 +280,10 @@ def is_onnx_available():
|
||||
return _onnx_available
|
||||
|
||||
|
||||
def is_opencv_available():
|
||||
return _opencv_available
|
||||
|
||||
|
||||
def is_scipy_available():
|
||||
return _scipy_available
|
||||
|
||||
@@ -332,6 +344,12 @@ ONNX_IMPORT_ERROR = """
|
||||
install onnxruntime`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
OPENCV_IMPORT_ERROR = """
|
||||
{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip
|
||||
install opencv-python`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
SCIPY_IMPORT_ERROR = """
|
||||
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
|
||||
@@ -391,6 +409,7 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
|
||||
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
|
||||
("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
|
||||
("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)),
|
||||
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
||||
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
||||
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
|
||||
|
||||
@@ -3,12 +3,13 @@ import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import tempfile
|
||||
import unittest
|
||||
import urllib.parse
|
||||
from distutils.util import strtobool
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -16,7 +17,14 @@ import PIL.ImageOps
|
||||
import requests
|
||||
from packaging import version
|
||||
|
||||
from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available
|
||||
from .import_utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_compel_available,
|
||||
is_flax_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
@@ -253,6 +261,23 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
return image
|
||||
|
||||
|
||||
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
|
||||
if is_opencv_available():
|
||||
import cv2
|
||||
else:
|
||||
raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
|
||||
if output_video_path is None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
h, w, c = video_frames[0].shape
|
||||
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
|
||||
for i in range(len(video_frames)):
|
||||
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
|
||||
video_writer.write(img)
|
||||
return output_video_path
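A small standalone sketch of the helper above, using synthetic RGB frames; in practice one would pass the `frames` returned by `TextToVideoSDPipeline`. The (H, W, 3) uint8 frame layout follows from the `cv2` calls in the function.

import numpy as np
from diffusers.utils import export_to_video

# sixteen random 64x64 RGB frames as placeholders for real pipeline output
frames = [np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) for _ in range(16)]
video_path = export_to_video(frames)  # writes an 8 fps .mp4 to a temporary file
print(video_path)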
|
||||
|
||||
|
||||
def load_hf_numpy(path) -> np.ndarray:
|
||||
if not path.startswith("http://") or path.startswith("https://"):
|
||||
path = os.path.join(
|
||||
|
||||
tests/models/test_models_unet_3d_condition.py (new file, 242 lines)
@@ -0,0 +1,242 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.models import ModelMixin, UNet3DConditionModel
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.utils import (
|
||||
floats_tensor,
|
||||
logging,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
def create_lora_layers(model):
|
||||
lora_attn_procs = {}
|
||||
for name in model.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
||||
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
|
||||
|
||||
# add 1 to weights to mock trained weights
|
||||
with torch.no_grad():
|
||||
lora_attn_procs[name].to_q_lora.up.weight += 1
|
||||
lora_attn_procs[name].to_k_lora.up.weight += 1
|
||||
lora_attn_procs[name].to_v_lora.up.weight += 1
|
||||
lora_attn_procs[name].to_out_lora.up.weight += 1
|
||||
|
||||
return lora_attn_procs
|
||||
|
||||
|
||||
class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet3DConditionModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
num_frames = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (32, 64, 64, 64),
|
||||
"down_block_types": (
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"DownBlock3D",
|
||||
),
|
||||
"up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
||||
"cross_attention_dim": 32,
|
||||
"attention_head_dim": 4,
|
||||
"out_channels": 4,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"sample_size": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
# Overriding because `block_out_channels` needs to be different for this model.
|
||||
def test_forward_with_norm_groups(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["norm_num_groups"] = 32
|
||||
init_dict["block_out_channels"] = (32, 64, 64, 64)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
# Overriding since the UNet3D outputs a different structure.
|
||||
def test_determinism(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps" and isinstance(model, ModelMixin):
|
||||
model(**self.dummy_input)
|
||||
|
||||
first = model(**inputs_dict)
|
||||
if isinstance(first, dict):
|
||||
first = first.sample
|
||||
|
||||
second = model(**inputs_dict)
|
||||
if isinstance(second, dict):
|
||||
second = second.sample
|
||||
|
||||
out_1 = first.cpu().numpy()
|
||||
out_2 = second.cpu().numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_model_attention_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = 8
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.set_attention_slice("auto")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
model.set_attention_slice("max")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
model.set_attention_slice(2)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
# (`attn_processors`) needs to be implemented in this model for this test.
|
||||
# def test_lora_processors(self):
|
||||
|
||||
# (`attn_processors`) needs to be implemented in this model for this test.
|
||||
# def test_lora_save_load(self):
|
||||
|
||||
# (`attn_processors`) needs to be implemented for this test in the model.
|
||||
# def test_lora_save_load_safetensors(self):
|
||||
|
||||
# (`attn_processors`) needs to be implemented for this test in the model.
|
||||
# def test_lora_save_safetensors_load_torch(self):
|
||||
|
||||
# (`attn_processors`) needs to be implemented for this test.
|
||||
# def test_lora_save_torch_force_load_safetensors_error(self):
|
||||
|
||||
# (`attn_processors`) needs to be added for this test.
|
||||
# def test_lora_on_off(self):
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_lora_xformers_on_off(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = 4
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
lora_attn_procs = create_lora_layers(model)
|
||||
model.set_attn_processor(lora_attn_procs)
|
||||
|
||||
# default
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
on_sample = model(**inputs_dict).sample
|
||||
|
||||
model.disable_xformers_memory_efficient_attention()
|
||||
off_sample = model(**inputs_dict).sample
|
||||
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
|
||||
|
||||
# (todo: sayakpaul) implement SLOW tests.
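The tiny configuration used in `prepare_init_args_and_inputs_for_common` above can also be exercised directly; a hedged sketch of one forward pass with the same dummy shapes (4 frames of 32x32 latents, a 32-dim text context):

import torch
from diffusers.models import UNet3DConditionModel

model = UNet3DConditionModel(
    block_out_channels=(32, 64, 64, 64),
    down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
    cross_attention_dim=32,
    attention_head_dim=4,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=32,
)

sample = torch.randn(4, 4, 4, 32, 32)           # (batch, channels, frames, height, width)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(4, 4, 32)   # (batch, sequence_length, cross_attention_dim)

with torch.no_grad():
    out = model(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample
assert out.shape == sample.shape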
|
||||
tests/pipelines/text_to_video/__init__.py (new file, 0 lines)
tests/pipelines/text_to_video/test_text_to_video.py (new file, 196 lines)
@@ -0,0 +1,196 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
TextToVideoSDPipeline,
|
||||
UNet3DConditionModel,
|
||||
)
|
||||
from diffusers.utils import load_numpy, skip_mps, slow
|
||||
|
||||
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = TextToVideoSDPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
# No `output_type`.
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback",
|
||||
"callback_steps",
|
||||
]
|
||||
)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet3DConditionModel(
|
||||
block_out_channels=(32, 64, 64, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"),
|
||||
up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
||||
cross_attention_dim=32,
|
||||
attention_head_dim=4,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=512,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_text_to_video_default_case(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = TextToVideoSDPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["output_type"] = "np"
|
||||
frames = sd_pipe(**inputs).frames
|
||||
image_slice = frames[0][-3:, -3:, -1]
|
||||
|
||||
assert frames[0].shape == (64, 64, 3)
|
||||
expected_slice = np.array([166, 184, 167, 118, 102, 123, 108, 93, 114])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
|
||||
|
||||
# (todo): sayakpaul
|
||||
@unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
|
||||
def test_inference_batch_consistent(self):
|
||||
pass
|
||||
|
||||
# (todo): sayakpaul
|
||||
@unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
|
||||
def test_num_images_per_prompt(self):
|
||||
pass
|
||||
|
||||
@skip_mps
|
||||
def test_progress_bar(self):
|
||||
return super().test_progress_bar()
|
||||
|
||||
|
||||
@slow
|
||||
class TextToVideoSDPipelineSlowTests(unittest.TestCase):
|
||||
def test_full_model(self):
|
||||
expected_video = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video.npy"
|
||||
)
|
||||
|
||||
pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Spiderman is surfing"
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
video_frames = pipe(prompt, generator=generator, num_inference_steps=25, output_type="pt").frames
|
||||
video = video_frames.cpu().numpy()
|
||||
|
||||
assert np.abs(expected_video - video).mean() < 5e-2
|
||||
|
||||
def test_two_step_model(self):
|
||||
expected_video = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy"
|
||||
)
|
||||
|
||||
pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Spiderman is surfing"
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
|
||||
video = video_frames.cpu().numpy()
|
||||
|
||||
assert np.abs(expected_video - video).mean() < 5e-2
|
||||
@@ -20,6 +20,13 @@ from diffusers.utils.testing_utils import require_torch, torch_device
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
tensor = tensor.detach().cpu().numpy()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@require_torch
|
||||
class PipelineTesterMixin:
|
||||
"""
|
||||
@@ -130,7 +137,7 @@ class PipelineTesterMixin:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output - output_loaded).max()
|
||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||
self.assertLess(max_diff, 1e-4)
|
||||
|
||||
def test_pipeline_call_signature(self):
|
||||
@@ -327,7 +334,7 @@ class PipelineTesterMixin:
|
||||
output = pipe(**self.get_dummy_inputs(torch_device))[0]
|
||||
output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0]
|
||||
|
||||
max_diff = np.abs(output - output_tuple).max()
|
||||
max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
|
||||
self.assertLess(max_diff, 1e-4)
|
||||
|
||||
def test_components_function(self):
|
||||
@@ -351,7 +358,7 @@ class PipelineTesterMixin:
|
||||
output = pipe(**self.get_dummy_inputs(torch_device))[0]
|
||||
output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0]
|
||||
|
||||
max_diff = np.abs(output - output_fp16).max()
|
||||
max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
|
||||
self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.")
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
|
||||
@@ -383,7 +390,7 @@ class PipelineTesterMixin:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output - output_loaded).max()
|
||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||
self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.")
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
@@ -421,7 +428,7 @@ class PipelineTesterMixin:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output - output_loaded).max()
|
||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||
self.assertLess(max_diff, 1e-4)
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
|
||||
@@ -442,7 +449,7 @@ class PipelineTesterMixin:
|
||||
self.assertTrue(all(device == "cuda" for device in model_devices))
|
||||
|
||||
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
|
||||
self.assertTrue(np.isnan(output_cuda).sum() == 0)
|
||||
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
|
||||
|
||||
def test_to_dtype(self):
|
||||
components = self.get_dummy_components()
|
||||
@@ -482,7 +489,7 @@ class PipelineTesterMixin:
|
||||
output_with_slicing = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff = np.abs(output_with_slicing - output_without_slicing).max()
|
||||
max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
|
||||
|
||||
if test_mean_pixel_difference:
|
||||
@@ -508,7 +515,7 @@ class PipelineTesterMixin:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_with_offload = pipe(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output_with_offload - output_without_offload).max()
|
||||
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
|
||||
self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results")
|
||||
|
||||
@unittest.skipIf(
|
||||
|
||||