1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
* begin model

* finish blocks

* add_embedding

* addition_time_embed_dim

* use TimestepEmbedding

* fix temporal res block

* fix time_pos_embed

* fix add_embedding

* add conversion script

* fix model

* up

* add new resnet blocks

* make forward work

* return sample in original shape

* fix temb shape in TemporalResnetBlock

* add spatio temporal transformers

* add vae blocks

* fix blocks

* update

* update

* fix shapes in Alphablender and add time activation in res blcok

* use new blocks

* style

* fix temb shape

* fix SpatioTemporalResBlock

* reuse TemporalBasicTransformerBlock

* fix TemporalBasicTransformerBlock

* use TransformerSpatioTemporalModel

* fix TransformerSpatioTemporalModel

* fix time_context dim

* clean up

* make temb optional

* add blocks

* rename model

* update conversion script

* remove UNetMidBlockSpatioTemporal

* add in init

* remove unused arg

* remove unused arg

* remove more unsed args

* up

* up

* check for None

* update vae

* update up/mid blocks for decoder

* begin pipeline

* adapt scheduler

* add guidance scalings

* fix norm eps in temporal transformers

* add temporal autoencoder

* make pipeline run

* fix frame decodig

* decode in float32

* decode n frames at a time

* pass decoding_t to decode_latents

* fix decode_latents

* vae encode/decode in fp32

* fix dtype in TransformerSpatioTemporalModel

* type image_latents same as image_embeddings

* allow using differnt eps in temporal block for video decoder

* fix default values in vae

* pass num frames in decode

* switch spatial to temporal for mixing in VAE

* fix num frames during split decoding

* cast alpha to sample dtype

* fix attention in MidBlockTemporalDecoder

* fix typo

* fix guidance_scales dtype

* fix missing activation in TemporalDecoder

* skip_post_quant_conv

* add vae conversion

* style

* take guidance scale as input

* up

* allow passing PIL to export_video

* accept fps as arg

* add pipeline and vae in init

* remove hack

* use AutoencoderKLTemporalDecoder

* don't scale image latents

* add unet tests

* clean up unet

* clean TransformerSpatioTemporalModel

* add slow svd test

* clean up

* make temb optional in Decoder mid block

* fix norm eps in TransformerSpatioTemporalModel

* clean up temp decoder

* clean up

* clean up

* use c_noise values for timesteps

* use math for log

* update

* fix copies

* doc

* upcast vae

* update forward pass for gradient checkpointing

* make added_time_ids is tensor

* up

* fix upcasting

* remove post quant conv

* add _resize_with_antialiasing

* fix _compute_padding

* cleanup model

* more cleanup

* more cleanup

* more cleanup

* remove freeu

* remove attn slice

* small clean

* up

* up

* remove extra step kwargs

* remove eta

* remove dropout

* remove callback

* remove merge factor args

* clean

* clean up

* move to dedicated folder

* remove attention_head_dim

* docstr and small fix

* update unet doc strings

* rename decoding_t

* correct linting

* store c_skip and c_out

* cleanup

* clean TemporalResnetBlock

* more cleanup

* clean up vae

* clean up

* begin doc

* more cleanup

* up

* up

* doc

* Improve

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* Apply suggestions from code review

* Default chunk size to None

* add example

* Better

* Apply suggestions from code review

* update doc

* Update src/diffusers/pipelines/stable_diffusion_video/pipeline_stable_diffusion_video.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* style

* Get torch compile working

* up

* rename

* fix doc

* add chunking

* torch compile

* torch compile

* add modelling outputs

* torch compile

* Improve chunking

* Apply suggestions from code review

* Update docs/source/en/using-diffusers/svd.md

* Close diff tag

* remove slicing

* resnet docstr

* add docstr in resnet

* rename

* Apply suggestions from code review

* update tests

* Fix output type latents

* fix more

* fix more

* Update docs/source/en/using-diffusers/svd.md

* fix more

* add pipeline tests

* remove unused arg

* clean  up

* make sure get_scaling receives tensors

* fix euler scheduler

* fix get_scalings

* simply euler for now

* remove old test file

* use randn_tensor to create noise

* fix device for rand tensor

* increase expected_max_difference

* fix test_inference_batch_single_identical

* actually fix test_inference_batch_single_identical

* disable test_save_load_float16

* skip test_float16_inference

* skip test_inference_batch_single_identical

* fix test_xformers_attention_forwardGenerator_pass

* Apply suggestions from code review

* update StableVideoDiffusionPipelineSlowTests

* update image

* add diffusers example

* fix more

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
This commit is contained in:
Suraj Patil
2023-11-29 19:13:36 +01:00
committed by GitHub
parent d1b2a1a957
commit 63f767ef15
38 changed files with 5287 additions and 149 deletions

View File

@@ -82,7 +82,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
The following design principles are followed:
- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modelling files and shows that models do not really follow the single-file policy.
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
- Models all inherit from `ModelMixin` and `ConfigMixin`.
- Models can be optimized for performance when it doesnt demand major code changes, keep backward compatibility, and give significant memory or compute gain.

View File

@@ -94,6 +94,8 @@
title: Latent Consistency Model-LoRA
- local: using-diffusers/inference_with_lcm
title: Latent Consistency Model
- local: using-diffusers/svd
title: Stable Video Diffusion
title: Specific pipeline examples
- sections:
- local: training/overview

View File

@@ -0,0 +1,131 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Video Diffusion
[[open-in-colab]]
[Stable Video Diffusion](https://static1.squarespace.com/static/6213c340453c3f502425776e/t/655ce779b9d47d342a93c890/1700587395994/stable_video_diffusion.pdf) is a powerful image-to-video generation model that can generate high resolution (576x1024) 2-4 second videos conditioned on the input image.
This guide will show you how to use SVD to short generate videos from images.
Before you begin, make sure you have the following libraries installed:
```py
!pip install -q -U diffusers transformers accelerate
```
## Image to Video Generation
The are two variants of SVD. [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid)
and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). The svd checkpoint is trained to generate 14 frames and the svd-xt checkpoint is further
finetuned to generate 25 frames.
We will use the `svd-xt` checkpoint for this guide.
```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()
# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))
generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
<video width="1024" height="576" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4?download=true" type="video/mp4">
</video>
<Tip>
Since generating videos is more memory intensive we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory.
Setting `decode_chunk_size=1` will decode one frame at a time and will use the least amount of memory but the video might have some flickering.
Additionally, we also use [model cpu offloading](../../optimization/memory#model-offloading) to reduce the memory usage.
</Tip>
### Torch.compile
You can achieve a 20-25% speed-up at the expense of slightly increased memory by compiling the UNet as follows:
```diff
- pipe.enable_model_cpu_offload()
+ pipe.to("cuda")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
### Low-memory
Video generation is very memory intensive as we have to essentially generate `num_frames` all at once. The mechanism is very comparable to text-to-image generation with a high batch size. To reduce the memory requirement you have multiple options. The following options trade inference speed against lower memory requirement:
- enable model offloading: Each component of the pipeline is offloaded to CPU once it's not needed anymore.
- enable feed-forward chunking: The feed-forward layer runs in a loop instead of running with a single huge feed-forward batch size
- reduce `decode_chunk_size`: This means that the VAE decodes frames in chunks instead of decoding them all together. **Note**: In addition to leading to a small slowdown, this method also slightly leads to video quality deterioration
You can enable them as follows:
```diff
-pipe.enable_model_cpu_offload()
-frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
+pipe.enable_model_cpu_offload()
+pipe.unet.enable_forward_chunking()
+frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
```
Including all these tricks should lower the memory requirement to less than 8GB VRAM.
### Micro-conditioning
Along with conditioning image Stable Diffusion Video also allows providing micro-conditioning that allows more control over the generated video.
It accepts the following arguments:
- `fps`: The frames per second of the generated video.
- `motion_bucket_id`: The motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id will increase the motion of the generated video.
- `noise_aug_strength`: The amount of noise added to the conditioning image. The higher the values the less the video will resemble the conditioning image. Increasing this value will also increase the motion of the generated video.
Here is an example of using micro-conditioning to generate a video with more motion.
```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()
# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))
generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=180, noise_aug_strength=0.1).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
<video width="1024" height="576" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated_motion.mp4?download=true" type="video/mp4">
</video>

View File

@@ -0,0 +1,730 @@
from diffusers.utils import is_accelerate_available, logging
if is_accelerate_available():
pass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
unet_params = original_config.model.params.unet_config.params
else:
unet_params = original_config.model.params.network_config.params
vae_params = original_config.model.params.first_stage_config.params.encoder_config.params
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = (
"CrossAttnDownBlockSpatioTemporal"
if resolution in unet_params.attention_resolutions
else "DownBlockSpatioTemporal"
)
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = (
"CrossAttnUpBlockSpatioTemporal"
if resolution in unet_params.attention_resolutions
else "UpBlockSpatioTemporal"
)
up_block_types.append(block_type)
resolution //= 2
if unet_params.transformer_depth is not None:
transformer_layers_per_block = (
unet_params.transformer_depth
if isinstance(unet_params.transformer_depth, int)
else list(unet_params.transformer_depth)
)
else:
transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
class_embed_type = None
addition_embed_type = None
addition_time_embed_dim = None
projection_class_embeddings_input_dim = None
context_dim = None
if unet_params.context_dim is not None:
context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
)
if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":
addition_time_embed_dim = 256
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type,
"addition_embed_type": addition_embed_type,
"addition_time_embed_dim": addition_time_embed_dim,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
"transformer_layers_per_block": transformer_layers_per_block,
}
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
config["num_class_embeds"] = unet_params.num_classes
if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
config["out_channels"] = unet_params.out_channels
config["up_block_types"] = tuple(up_block_types)
return config
def assign_to_checkpoint(
paths,
checkpoint,
old_checkpoint,
attention_paths_to_split=None,
additional_replacements=None,
config=None,
mid_block_suffix="",
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
if mid_block_suffix is not None:
mid_block_suffix = f".{mid_block_suffix}"
else:
mid_block_suffix = ""
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
if new_path == "mid_block.resnets.0.spatial_res_block.norm1.weight":
print("yeyy")
# proj_attn.weight has to be converted from conv 1D to linear
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
shape = old_checkpoint[path["old"]].shape
if is_attn_weight and len(shape) == 3:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
elif is_attn_weight and len(shape) == 4:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
new_item = new_item.replace("time_stack", "temporal_transformer_blocks")
new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias")
new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight")
new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias")
new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight")
mapping.append({"old": old_item, "new": new_item})
return mapping
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = new_item.replace("time_stack.", "")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_unet_checkpoint(
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
if skip_extract_state_dict:
unet_state_dict = checkpoint
else:
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
logger.warning(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.warning(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
if config["class_embed_type"] is None:
# No parameters to port
...
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
else:
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
# if config["addition_embed_type"] == "text_time":
new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
spatial_resnets = [
key
for key in input_blocks[i]
if f"input_blocks.{i}.0" in key
and (
f"input_blocks.{i}.0.op" not in key
and f"input_blocks.{i}.0.time_stack" not in key
and f"input_blocks.{i}.0.time_mixer" not in key
)
]
temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key]
# import ipdb; ipdb.set_trace()
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(spatial_resnets)
meta_path = {
"old": f"input_blocks.{i}.0",
"new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
paths = renew_resnet_paths(temporal_resnets)
meta_path = {
"old": f"input_blocks.{i}.0",
"new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
# TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
# import ipdb; ipdb.set_trace()
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key]
resnet_0_paths = renew_resnet_paths(resnet_0_spatial)
# import ipdb; ipdb.set_trace()
assign_to_checkpoint(
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
)
resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key]
resnet_0_paths = renew_resnet_paths(resnet_0_temporal)
assign_to_checkpoint(
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
)
resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key]
resnet_1_paths = renew_resnet_paths(resnet_1_spatial)
assign_to_checkpoint(
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
)
resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key]
resnet_1_paths = renew_resnet_paths(resnet_1_temporal)
assign_to_checkpoint(
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
)
new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
"middle_block.0.time_mixer.mix_factor"
]
new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[
"middle_block.2.time_mixer.mix_factor"
]
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
spatial_resnets = [
key
for key in output_blocks[i]
if f"output_blocks.{i}.0" in key
and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key)
]
# import ipdb; ipdb.set_trace()
temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key]
paths = renew_resnet_paths(spatial_resnets)
meta_path = {
"old": f"output_blocks.{i}.0",
"new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
paths = renew_resnet_paths(temporal_resnets)
meta_path = {
"old": f"output_blocks.{i}.0",
"new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key]
if len(attentions):
paths = renew_attention_paths(attentions)
# import ipdb; ipdb.set_trace()
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
spatial_layers = [
layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer
]
resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1)
# import ipdb; ipdb.set_trace()
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(
["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]]
)
new_checkpoint[new_path] = unet_state_dict[old_path]
temporal_layers = [
layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in key
]
resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1)
# import ipdb; ipdb.set_trace()
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(
["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]]
)
new_checkpoint[new_path] = unet_state_dict[old_path]
new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
f"output_blocks.{str(i)}.0.time_mixer.mix_factor"
]
return new_checkpoint
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# Temporal resnet
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = new_item.replace("time_stack.", "temporal_res_block.")
# Spatial resnet
new_item = new_item.replace("conv1", "spatial_res_block.conv1")
new_item = new_item.replace("norm1", "spatial_res_block.norm1")
new_item = new_item.replace("conv2", "spatial_res_block.conv2")
new_item = new_item.replace("norm2", "spatial_res_block.norm2")
new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut")
new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
keys = list(checkpoint.keys())
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"]
new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"]
# new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
# new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
# new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
# new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint

View File

@@ -76,6 +76,7 @@ else:
[
"AsymmetricAutoencoderKL",
"AutoencoderKL",
"AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
@@ -92,6 +93,7 @@ else:
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"VQModel",
]
)
@@ -277,6 +279,7 @@ else:
"StableDiffusionXLPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
"StableVideoDiffusionPipeline",
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
"TextToVideoZeroSDXLPipeline",
@@ -447,6 +450,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .models import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
@@ -463,6 +467,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
VQModel,
)
from .optimization import (
@@ -627,6 +632,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
StableVideoDiffusionPipeline,
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
TextToVideoZeroSDXLPipeline,

View File

@@ -14,7 +14,12 @@
from typing import TYPE_CHECKING
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
_LazyModule,
is_flax_available,
is_torch_available,
)
_import_structure = {}
@@ -23,6 +28,7 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
@@ -38,6 +44,7 @@ if is_torch_available():
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
@@ -51,6 +58,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
@@ -66,6 +74,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
if is_flax_available():

View File

@@ -25,6 +25,31 @@ from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero
def _chunked_feed_forward(
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
if lora_scale is None:
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
else:
# TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
ff_output = torch.cat(
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
r"""
@@ -194,7 +219,12 @@ class BasicTransformerBlock(nn.Module):
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
@@ -208,7 +238,7 @@ class BasicTransformerBlock(nn.Module):
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
@@ -311,18 +341,8 @@ class BasicTransformerBlock(nn.Module):
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[
self.ff(hid_slice, scale=lora_scale)
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
],
dim=self._chunk_dim,
ff_output = _chunked_feed_forward(
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
@@ -339,6 +359,137 @@ class BasicTransformerBlock(nn.Module):
return hidden_states
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video like data.
Parameters:
dim (`int`): The number of channels in the input and output.
time_mix_inner_dim (`int`): The number of channels for temporal attention.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
self.norm_in = nn.LayerNorm(dim)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
activation_fn="geglu",
)
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = None
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
# Sets chunk feed-forward
self._chunk_size = chunk_size
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
self._chunk_dim = 1
def forward(
self,
hidden_states: torch.FloatTensor,
num_frames: int,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.

View File

@@ -18,7 +18,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils.accelerate_utils import apply_forward_hook
from .autoencoder_kl import AutoencoderKLOutput
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import torch
@@ -19,7 +18,6 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -28,24 +26,11 @@ from .attention_processor import (
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

View File

@@ -0,0 +1,402 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import is_torch_version
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
def __init__(
self,
in_channels: int = 4,
out_channels: int = 3,
block_out_channels: Tuple[int] = (128, 256, 512, 512),
layers_per_block: int = 2,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.mid_block = MidBlockTemporalDecoder(
num_layers=self.layers_per_block,
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
attention_head_dim=block_out_channels[-1],
)
# up
self.up_blocks = nn.ModuleList([])
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpBlockTemporalDecoder(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = torch.nn.Conv2d(
in_channels=block_out_channels[0],
out_channels=out_channels,
kernel_size=3,
padding=1,
)
conv_out_kernel_size = (3, 1, 1)
padding = [int(k // 2) for k in conv_out_kernel_size]
self.time_conv_out = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=conv_out_kernel_size,
padding=padding,
)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
num_frames: int = 1,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
)
else:
# middle
sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, image_only_indicator=image_only_indicator)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
batch_frames, channels, height, width = sample.shape
batch_size = batch_frames // num_frames
sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
sample = self.time_conv_out(sample)
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
return sample
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
latent_channels: int = 4,
sample_size: int = 32,
scaling_factor: float = 0.18215,
force_upcast: float = True,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
double_z=True,
)
# pass init params to Decoder
self.decoder = TemporalDecoder(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, TemporalDecoder)):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
num_frames: int,
return_dict: bool = True,
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
batch_size = z.shape[0] // num_frames
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
num_frames: int = 1,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, num_frames=num_frames).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -0,0 +1,17 @@
from dataclasses import dataclass
from ..utils import BaseOutput
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution" # noqa: F821

View File

@@ -165,7 +165,10 @@ class Upsample2D(nn.Module):
self.Conv2d_0 = conv
def forward(
self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
@@ -379,7 +382,11 @@ class FirUpsample2D(nn.Module):
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
inverse_conv = F.conv_transpose2d(
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
hidden_states,
weight,
stride=stride,
output_padding=output_padding,
padding=0,
)
output = upfirdn2d_native(
@@ -530,7 +537,14 @@ class KDownsample2D(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -553,7 +567,14 @@ class KUpsample2D(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -690,11 +711,19 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
in_channels,
conv_2d_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=conv_shortcut_bias,
)
def forward(
self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
@@ -866,7 +895,10 @@ class ResidualTemporalBlock1D(nn.Module):
def upsample_2d(
hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@@ -910,7 +942,10 @@ def upsample_2d(
def downsample_2d(
hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
@@ -946,13 +981,20 @@ def downsample_2d(
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
hidden_states,
kernel.to(device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
def upfirdn2d_native(
tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
tensor: torch.Tensor,
kernel: torch.Tensor,
up: int = 1,
down: int = 1,
pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
@@ -1008,7 +1050,13 @@ class TemporalConvLayer(nn.Module):
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
"""
def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, norm_num_groups: int = 32):
def __init__(
self,
in_dim: int,
out_dim: Optional[int] = None,
dropout: float = 0.0,
norm_num_groups: int = 32,
):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
@@ -1016,7 +1064,9 @@ class TemporalConvLayer(nn.Module):
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
nn.GroupNorm(norm_num_groups, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv2 = nn.Sequential(
nn.GroupNorm(norm_num_groups, out_dim),
@@ -1058,3 +1108,261 @@ class TemporalConvLayer(nn.Module):
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
)
return hidden_states
class TemporalResnetBlock(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
temb_channels: int = 512,
eps: float = 1e-6,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
kernel_size = (3, 1, 1)
padding = [k // 2 for k in kernel_size]
self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = nn.Conv3d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
)
if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
else:
self.time_emb_proj = None
self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(0.0)
self.conv2 = nn.Conv3d(
out_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
)
self.nonlinearity = get_activation("silu")
self.use_in_shortcut = self.in_channels != out_channels
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = nn.Conv3d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, :, None, None]
temb = temb.permute(0, 2, 1, 3, 4)
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
return output_tensor
# VideoResBlock
class SpatioTemporalResBlock(nn.Module):
r"""
A SpatioTemporal Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resenet.
temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
The merge strategy to use for the temporal mixing.
switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
If `True`, switch the spatial and temporal mixing.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
temb_channels: int = 512,
eps: float = 1e-6,
temporal_eps: Optional[float] = None,
merge_factor: float = 0.5,
merge_strategy="learned_with_images",
switch_spatial_to_temporal_mix: bool = False,
):
super().__init__()
self.spatial_res_block = ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=eps,
)
self.temporal_res_block = TemporalResnetBlock(
in_channels=out_channels if out_channels is not None else in_channels,
out_channels=out_channels if out_channels is not None else in_channels,
temb_channels=temb_channels,
eps=temporal_eps if temporal_eps is not None else eps,
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
)
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
):
num_frames = image_only_indicator.shape[-1]
hidden_states = self.spatial_res_block(hidden_states, temb)
batch_frames, channels, height, width = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states_mix = (
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
)
hidden_states = (
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
)
if temb is not None:
temb = temb.reshape(batch_size, num_frames, -1)
hidden_states = self.temporal_res_block(hidden_states, temb)
hidden_states = self.time_mixer(
x_spatial=hidden_states_mix,
x_temporal=hidden_states,
image_only_indicator=image_only_indicator,
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
return hidden_states
class AlphaBlender(nn.Module):
r"""
A module to blend spatial and temporal features.
Parameters:
alpha (`float`): The initial value of the blending factor.
merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
The merge strategy to use for the temporal mixing.
switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
If `True`, switch the spatial and temporal mixing.
"""
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
switch_spatial_to_temporal_mix: bool = False,
):
super().__init__()
self.merge_strategy = merge_strategy
self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE
if merge_strategy not in self.strategies:
raise ValueError(f"merge_strategy needs to be in {self.strategies}")
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
if self.merge_strategy == "fixed":
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
elif self.merge_strategy == "learned_with_images":
if image_only_indicator is None:
raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
torch.sigmoid(self.mix_factor)[..., None],
)
# (batch, channel, frames, height, width)
if ndims == 5:
alpha = alpha[:, None, :, None, None]
# (batch*frames, height*width, channels)
elif ndims == 3:
alpha = alpha.reshape(-1)[:, None, None]
else:
raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
else:
raise NotImplementedError
return alpha
def forward(
self,
x_spatial: torch.Tensor,
x_temporal: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
alpha = alpha.to(x_spatial.dtype)
if self.switch_spatial_to_temporal_mix:
alpha = 1.0 - alpha
x = alpha * x_spatial + (1.0 - alpha) * x_temporal
return x

View File

@@ -19,8 +19,10 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .resnet import AlphaBlender
@dataclass
@@ -195,3 +197,183 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
return (output,)
return TransformerTemporalModelOutput(sample=output)
class TransformerSpatioTemporalModel(nn.Module):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
out_channels (`int`, *optional*):
The number of channels in the output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: int = 320,
out_channels: Optional[int] = None,
num_layers: int = 1,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
# 2. Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for d in range(num_layers)
]
)
time_mix_inner_dim = inner_dim
self.temporal_transformer_blocks = nn.ModuleList(
[
TemporalBasicTransformerBlock(
inner_dim,
time_mix_inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for _ in range(num_layers)
]
)
time_embed_dim = in_channels * 4
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
# TODO: should use out_channels for continuous projections
self.proj_out = nn.Linear(inner_dim, in_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input hidden_states.
num_frames (`int`):
The number of frames to be processed per batch. This is used to reshape the hidden states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
images, 0 indicates that the input contains video frames.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
returned, otherwise a `tuple` where the first element is the sample tensor.
"""
# 1. Input
batch_frames, _, height, width = hidden_states.shape
num_frames = image_only_indicator.shape[-1]
batch_size = batch_frames // num_frames
time_context = encoder_hidden_states
time_context_first_timestep = time_context[None, :].reshape(
batch_size, num_frames, -1, time_context.shape[-1]
)[:, 0]
time_context = time_context_first_timestep[None, :].broadcast_to(
height * width, batch_size, 1, time_context.shape[-1]
)
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
num_frames_emb = num_frames_emb.reshape(-1)
t_emb = self.time_proj(num_frames_emb)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
# 2. Blocks
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
None,
encoder_hidden_states,
None,
use_reentrant=False,
)
else:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb
hidden_states_mix = temporal_block(
hidden_states_mix,
num_frames=num_frames,
encoder_hidden_states=time_context,
)
hidden_states = self.time_mixer(
x_spatial=hidden_states,
x_temporal=hidden_states_mix,
image_only_indicator=image_only_indicator,
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
if not return_dict:
return (output,)
return TransformerTemporalModelOutput(sample=output)

View File

@@ -19,10 +19,20 @@ from torch import nn
from ..utils import is_torch_version
from ..utils.torch_utils import apply_freeu
from .attention import Attention
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
from .resnet import (
Downsample2D,
ResnetBlock2D,
SpatioTemporalResBlock,
TemporalConvLayer,
Upsample2D,
)
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_temporal import (
TransformerSpatioTemporalModel,
TransformerTemporalModel,
)
def get_down_block(
@@ -45,7 +55,15 @@ def get_down_block(
resnet_time_scale_shift: str = "default",
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
) -> Union["DownBlock3D", "CrossAttnDownBlock3D", "DownBlockMotion", "CrossAttnDownBlockMotion"]:
transformer_layers_per_block: int = 1,
) -> Union[
"DownBlock3D",
"CrossAttnDownBlock3D",
"DownBlockMotion",
"CrossAttnDownBlockMotion",
"DownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
]:
if down_block_type == "DownBlock3D":
return DownBlock3D(
num_layers=num_layers,
@@ -118,6 +136,29 @@ def get_down_block(
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
)
elif down_block_type == "DownBlockSpatioTemporal":
# added for SDV
return DownBlockSpatioTemporal(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
)
elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
# added for SDV
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
return CrossAttnDownBlockSpatioTemporal(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
add_downsample=add_downsample,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
)
raise ValueError(f"{down_block_type} does not exist.")
@@ -144,7 +185,16 @@ def get_up_block(
temporal_num_attention_heads: int = 8,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
) -> Union["UpBlock3D", "CrossAttnUpBlock3D", "UpBlockMotion", "CrossAttnUpBlockMotion"]:
transformer_layers_per_block: int = 1,
dropout: float = 0.0,
) -> Union[
"UpBlock3D",
"CrossAttnUpBlock3D",
"UpBlockMotion",
"CrossAttnUpBlockMotion",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
]:
if up_block_type == "UpBlock3D":
return UpBlock3D(
num_layers=num_layers,
@@ -221,6 +271,34 @@ def get_up_block(
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
)
elif up_block_type == "UpBlockSpatioTemporal":
# added for SDV
return UpBlockSpatioTemporal(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
add_upsample=add_upsample,
)
elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
# added for SDV
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
return CrossAttnUpBlockSpatioTemporal(
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
add_upsample=add_upsample,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
resolution_idx=resolution_idx,
)
raise ValueError(f"{up_block_type} does not exist.")
@@ -347,7 +425,10 @@ class UNetMidBlock3DCrossAttn(nn.Module):
return_dict=False,
)[0]
hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
hidden_states,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
@@ -443,7 +524,11 @@ class CrossAttnDownBlock3D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
)
]
)
@@ -476,7 +561,10 @@ class CrossAttnDownBlock3D(nn.Module):
return_dict=False,
)[0]
hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
hidden_states,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
output_states += (hidden_states,)
@@ -543,7 +631,11 @@ class DownBlock3D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
)
]
)
@@ -553,7 +645,10 @@ class DownBlock3D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, num_frames: int = 1
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
@@ -716,7 +811,10 @@ class CrossAttnUpBlock3D(nn.Module):
return_dict=False,
)[0]
hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
hidden_states,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
if self.upsamplers is not None:
@@ -890,7 +988,11 @@ class DownBlockMotion(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
)
]
)
@@ -920,14 +1022,20 @@ class DownBlockMotion(nn.Module):
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, scale
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames
create_custom_forward(motion_module),
hidden_states.requires_grad_(),
temb,
num_frames,
)
else:
@@ -1047,7 +1155,11 @@ class CrossAttnDownBlockMotion(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
)
]
)
@@ -1442,7 +1554,10 @@ class UpBlockMotion(nn.Module):
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
@@ -1636,3 +1751,645 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
class MidBlockTemporalDecoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
attention_head_dim: int = 512,
num_layers: int = 1,
upcast_attention: bool = False,
):
super().__init__()
resnets = []
attentions = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
SpatioTemporalResBlock(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
eps=1e-6,
temporal_eps=1e-5,
merge_factor=0.0,
merge_strategy="learned",
switch_spatial_to_temporal_mix=True,
)
)
attentions.append(
Attention(
query_dim=in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
eps=1e-6,
upcast_attention=upcast_attention,
norm_num_groups=32,
bias=True,
residual_connection=True,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(
self,
hidden_states: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
):
hidden_states = self.resnets[0](
hidden_states,
image_only_indicator=image_only_indicator,
)
for resnet, attn in zip(self.resnets[1:], self.attentions):
hidden_states = attn(hidden_states)
hidden_states = resnet(
hidden_states,
image_only_indicator=image_only_indicator,
)
return hidden_states
class UpBlockTemporalDecoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
add_upsample: bool = True,
):
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
SpatioTemporalResBlock(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
eps=1e-6,
temporal_eps=1e-5,
merge_factor=0.0,
merge_strategy="learned",
switch_spatial_to_temporal_mix=True,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
def forward(
self,
hidden_states: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(
hidden_states,
image_only_indicator=image_only_indicator,
)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetMidBlockSpatioTemporal(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
):
super().__init__()
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# there is always at least one resnet
resnets = [
SpatioTemporalResBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=1e-5,
)
]
attentions = []
for i in range(num_layers):
attentions.append(
TransformerSpatioTemporalModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
)
)
resnets.append(
SpatioTemporalResBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=1e-5,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
hidden_states = self.resnets[0](
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.training and self.gradient_checkpointing: # TODO
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
return hidden_states
class DownBlockSpatioTemporal(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
num_layers: int = 1,
add_downsample: bool = True,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
SpatioTemporalResBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=1e-5,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels,
use_conv=True,
out_channels=out_channels,
name="op",
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
)
else:
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
class CrossAttnDownBlockSpatioTemporal(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
add_downsample: bool = True,
):
super().__init__()
resnets = []
attentions = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
SpatioTemporalResBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=1e-6,
)
)
attentions.append(
TransformerSpatioTemporalModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels,
use_conv=True,
out_channels=out_channels,
padding=1,
name="op",
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
blocks = list(zip(self.resnets, self.attentions))
for resnet, attn in blocks:
if self.training and self.gradient_checkpointing: # TODO
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
else:
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
class UpBlockSpatioTemporal(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: Optional[int] = None,
num_layers: int = 1,
resnet_eps: float = 1e-6,
add_upsample: bool = True,
):
super().__init__()
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
SpatioTemporalResBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
)
else:
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class CrossAttnUpBlockSpatioTemporal(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: Optional[int] = None,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
add_upsample: bool = True,
):
super().__init__()
resnets = []
attentions = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
SpatioTemporalResBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
)
)
attentions.append(
TransformerSpatioTemporalModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing: # TODO
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
else:
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states

View File

@@ -0,0 +1,489 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
"""
The output of [`UNetSpatioTemporalConditionModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor = None
class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
The tuple of downsample blocks to use.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
The tuple of upsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
addition_time_embed_dim: (`int`, defaults to 256):
Dimension to to encode the additional time ids.
projection_class_embeddings_input_dim (`int`, defaults to 768):
The dimension of the projection of encoded `added_time_ids`.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
[`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
The number of attention heads.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 8,
out_channels: int = 4,
down_block_types: Tuple[str] = (
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
up_block_types: Tuple[str] = (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768,
layers_per_block: Union[int, Tuple[int]] = 2,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
num_frames: int = 25,
):
super().__init__()
self.sample_size = sample_size
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
# input
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=3,
padding=1,
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
blocks_time_embed_dim = time_embed_dim
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block[i],
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=1e-5,
cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
resnet_act_fn="silu",
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockSpatioTemporal(
block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block[-1],
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
)
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=reversed_layers_per_block[i] + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=blocks_time_embed_dim,
add_upsample=add_upsample,
resnet_eps=1e-5,
resolution_idx=i,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
resnet_act_fn="silu",
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(
block_out_channels[0],
out_channels,
kernel_size=3,
padding=1,
)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# By default chunk size is 1
chunk_size = chunk_size or 1
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
added_time_ids: torch.Tensor,
return_dict: bool = True,
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
r"""
The [`UNetSpatioTemporalConditionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
added_time_ids: (`torch.FloatTensor`):
The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
embeddings and added to the time embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
tuple.
Returns:
[`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
batch_size, num_frames = sample.shape[:2]
timesteps = timesteps.expand(batch_size)
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(added_time_ids.flatten())
time_embeds = time_embeds.reshape((batch_size, -1))
time_embeds = time_embeds.to(emb.dtype)
aug_emb = self.add_embedding(time_embeds)
emb = emb + aug_emb
# Flatten the batch and frames dimensions
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
emb = emb.repeat_interleave(num_frames, dim=0)
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
# 2. pre-process
sample = self.conv_in(sample)
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
image_only_indicator=image_only_indicator,
)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
)
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
image_only_indicator=image_only_indicator,
)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
# 7. Reshape back to original shape
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
if not return_dict:
return (sample,)
return UNetSpatioTemporalConditionOutput(sample=sample)

View File

@@ -22,7 +22,12 @@ from ..utils import BaseOutput, is_torch_version
from ..utils.torch_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
from .unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,
get_up_block,
)
@dataclass
@@ -274,7 +279,9 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
def forward(
self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
self,
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
@@ -292,14 +299,20 @@ class Decoder(nn.Module):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
else:
# middle
@@ -540,7 +553,10 @@ class MaskConditionDecoder(nn.Module):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
@@ -548,7 +564,10 @@ class MaskConditionDecoder(nn.Module):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
create_custom_forward(self.condition_encoder),
masked_image,
mask,
use_reentrant=False,
)
# up
@@ -558,7 +577,10 @@ class MaskConditionDecoder(nn.Module):
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
@@ -573,7 +595,9 @@ class MaskConditionDecoder(nn.Module):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder), masked_image, mask
create_custom_forward(self.condition_encoder),
masked_image,
mask,
)
# up
@@ -754,7 +778,10 @@ class DiagonalGaussianDistribution(object):
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
@@ -764,7 +791,10 @@ class DiagonalGaussianDistribution(object):
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
@@ -779,7 +809,10 @@ class DiagonalGaussianDistribution(object):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
@@ -820,7 +853,16 @@ class EncoderTiny(nn.Module):
if i == 0:
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
else:
layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
layers.append(
nn.Conv2d(
num_channels,
num_channels,
kernel_size=3,
padding=1,
stride=2,
bias=False,
)
)
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
@@ -899,7 +941,15 @@ class DecoderTiny(nn.Module):
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
conv_out_channel = num_channels if not is_final_block else out_channels
layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
layers.append(
nn.Conv2d(
num_channels,
conv_out_channel,
kernel_size=3,
padding=1,
bias=is_final_block,
)
)
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False

View File

@@ -17,7 +17,12 @@ from ..utils import (
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []}
_import_structure = {
"controlnet": [],
"latent_diffusion": [],
"stable_diffusion": [],
"stable_diffusion_xl": [],
}
try:
if not is_torch_available():
@@ -39,7 +44,11 @@ else:
_import_structure["dit"] = ["DiTPipeline"]
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
_import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
_import_structure["pipeline_utils"] = [
"AudioPipelineOutput",
"DiffusionPipeline",
"ImagePipelineOutput",
]
_import_structure["pndm"] = ["PNDMPipeline"]
_import_structure["repaint"] = ["RePaintPipeline"]
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
@@ -61,7 +70,10 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
_import_structure["alt_diffusion"] = [
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
]
_import_structure["animatediff"] = ["AnimateDiffPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -110,7 +122,10 @@ else:
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
_import_structure["kandinsky3"] = ["Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline"]
_import_structure["kandinsky3"] = [
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
]
_import_structure["latent_consistency_models"] = [
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
@@ -150,6 +165,7 @@ else:
]
)
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
_import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"]
_import_structure["stable_diffusion_xl"].extend(
[
"StableDiffusionXLImg2ImgPipeline",
@@ -158,7 +174,10 @@ else:
"StableDiffusionXLPipeline",
]
)
_import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
_import_structure["t2i_adapter"] = [
"StableDiffusionAdapterPipeline",
"StableDiffusionXLAdapterPipeline",
]
_import_structure["text_to_video_synthesis"] = [
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
@@ -216,7 +235,9 @@ try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
from ..utils import (
dummy_torch_and_transformers_and_k_diffusion_objects,
)
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
@@ -259,7 +280,10 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
_import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
_import_structure["spectrogram_diffusion"] = [
"MidiProcessor",
"SpectrogramDiffusionPipeline",
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -269,7 +293,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
from .auto_pipeline import (
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
)
from .consistency_models import ConsistencyModelPipeline
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
@@ -277,7 +305,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
from .pipeline_utils import (
AudioPipelineOutput,
DiffusionPipeline,
ImagePipelineOutput,
)
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
@@ -300,7 +332,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .animatediff import AnimateDiffPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .audioldm2 import (
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
)
from .blip_diffusion import BlipDiffusionPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
@@ -344,7 +380,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
)
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .latent_consistency_models import (
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
)
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
@@ -383,7 +422,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
from .stable_video_diffusion import StableVideoDiffusionPipeline
from .t2i_adapter import (
StableDiffusionAdapterPipeline,
StableDiffusionXLAdapterPipeline,
)
from .text_to_video_synthesis import (
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
@@ -473,7 +516,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
from .spectrogram_diffusion import (
MidiProcessor,
SpectrogramDiffusionPipeline,
)
else:
import sys

View File

@@ -55,7 +55,9 @@ try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
from ...utils.dummy_torch_and_transformers_objects import (
StableDiffusionImageVariationPipeline,
)
_dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline})
else:
@@ -90,7 +92,9 @@ try:
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
from ...utils import (
dummy_torch_and_transformers_and_k_diffusion_objects,
)
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
@@ -137,18 +141,32 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
from .pipeline_stable_diffusion_attend_and_excite import (
StableDiffusionAttendAndExcitePipeline,
)
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline
from .pipeline_stable_diffusion_gligen_text_image import (
StableDiffusionGLIGENTextImagePipeline,
)
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
from .pipeline_stable_diffusion_inpaint_legacy import (
StableDiffusionInpaintPipelineLegacy,
)
from .pipeline_stable_diffusion_instruct_pix2pix import (
StableDiffusionInstructPix2PixPipeline,
)
from .pipeline_stable_diffusion_latent_upscale import (
StableDiffusionLatentUpscalePipeline,
)
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
from .pipeline_stable_diffusion_model_editing import (
StableDiffusionModelEditingPipeline,
)
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
from .pipeline_stable_diffusion_paradigms import (
StableDiffusionParadigmsPipeline,
)
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .pipeline_stable_unclip import StableUnCLIPPipeline
@@ -160,9 +178,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
from ...utils.dummy_torch_and_transformers_objects import (
StableDiffusionImageVariationPipeline,
)
else:
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
from .pipeline_stable_diffusion_image_variation import (
StableDiffusionImageVariationPipeline,
)
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")):
@@ -174,9 +196,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionPix2PixZeroPipeline,
)
else:
from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline
from .pipeline_stable_diffusion_depth2img import (
StableDiffusionDepth2ImgPipeline,
)
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline
from .pipeline_stable_diffusion_pix2pix_zero import (
StableDiffusionPix2PixZeroPipeline,
)
try:
if not (
@@ -189,7 +215,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
from .pipeline_stable_diffusion_k_diffusion import (
StableDiffusionKDiffusionPipeline,
)
try:
if not (is_transformers_available() and is_onnx_available()):
@@ -197,11 +225,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_onnx_objects import *
else:
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy
from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
from .pipeline_onnx_stable_diffusion import (
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)
from .pipeline_onnx_stable_diffusion_img2img import (
OnnxStableDiffusionImg2ImgPipeline,
)
from .pipeline_onnx_stable_diffusion_inpaint import (
OnnxStableDiffusionInpaintPipeline,
)
from .pipeline_onnx_stable_diffusion_inpaint_legacy import (
OnnxStableDiffusionInpaintPipelineLegacy,
)
from .pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
try:
if not (is_transformers_available() and is_flax_available()):
@@ -210,8 +249,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_flax_objects import *
else:
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
from .pipeline_flax_stable_diffusion_img2img import (
FlaxStableDiffusionImg2ImgPipeline,
)
from .pipeline_flax_stable_diffusion_inpaint import (
FlaxStableDiffusionInpaintPipeline,
)
from .pipeline_output import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

View File

@@ -0,0 +1,58 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
BaseOutput,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure.update(
{
"pipeline_stable_video_diffusion": [
"StableVideoDiffusionPipeline",
"StableVideoDiffusionPipelineOutput",
],
}
)
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_stable_video_diffusion import (
StableVideoDiffusionPipeline,
StableVideoDiffusionPipelineOutput,
)
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,649 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import BaseOutput, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
return outputs
@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
r"""
Output class for zero-shot text-to-video pipeline.
Args:
frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
"""
frames: Union[List[PIL.Image.Image], np.ndarray]
class StableVideoDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline to generate video from an input image using Stable Video Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
unet ([`UNetSpatioTemporalConditionModel`]):
A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
scheduler ([`EulerDiscreteScheduler`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images.
"""
model_cpu_offload_seq = "image_encoder->unet->vae"
_callback_tensor_inputs = ["latents"]
def __init__(
self,
vae: AutoencoderKLTemporalDecoder,
image_encoder: CLIPVisionModelWithProjection,
unet: UNetSpatioTemporalConditionModel,
scheduler: EulerDiscreteScheduler,
feature_extractor: CLIPImageProcessor,
):
super().__init__()
self.register_modules(
vae=vae,
image_encoder=image_encoder,
unet=unet,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.image_processor.pil_to_numpy(image)
image = self.image_processor.numpy_to_pt(image)
# We normalize the image before resizing to match with the original implementation.
# Then we unnormalize it after resizing.
image = image * 2.0 - 1.0
image = _resize_with_antialiasing(image, (224, 224))
image = (image + 1.0) / 2.0
# Normalize the image with for CLIP input
image = self.feature_extractor(
images=image,
do_normalize=True,
do_center_crop=False,
do_resize=False,
do_rescale=False,
return_tensors="pt",
).pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.unsqueeze(1)
# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
negative_image_embeddings = torch.zeros_like(image_embeddings)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
return image_embeddings
def _encode_vae_image(
self,
image: torch.Tensor,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
):
image = image.to(device=device)
image_latents = self.vae.encode(image).latent_dist.mode()
if do_classifier_free_guidance:
negative_image_latents = torch.zeros_like(image_latents)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_latents = torch.cat([negative_image_latents, image_latents])
# duplicate image_latents for each generation per prompt, using mps friendly method
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
return image_latents
def _get_add_time_ids(
self,
fps,
motion_bucket_id,
noise_aug_strength,
dtype,
batch_size,
num_videos_per_prompt,
do_classifier_free_guidance,
):
add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
if do_classifier_free_guidance:
add_time_ids = torch.cat([add_time_ids, add_time_ids])
return add_time_ids
def decode_latents(self, latents, num_frames, decode_chunk_size=14):
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
latents = latents.flatten(0, 1)
latents = 1 / self.vae.config.scaling_factor * latents
accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
# decode decode_chunk_size frames at a time to avoid OOM
frames = []
for i in range(0, latents.shape[0], decode_chunk_size):
num_frames_in = latents[i : i + decode_chunk_size].shape[0]
decode_kwargs = {}
if accepts_num_frames:
# we only pass num_frames_in if it's expected
decode_kwargs["num_frames"] = num_frames_in
frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
frames.append(frame)
frames = torch.cat(frames, dim=0)
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
frames = frames.float()
return frames
def check_inputs(self, image, height, width):
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
def prepare_latents(
self,
batch_size,
num_frames,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_frames,
num_channels_latents // 2,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
height: int = 576,
width: int = 1024,
num_frames: Optional[int] = None,
num_inference_steps: int = 25,
min_guidance_scale: float = 1.0,
max_guidance_scale: float = 3.0,
fps: int = 7,
motion_bucket_id: int = 127,
noise_aug_strength: int = 0.02,
decode_chunk_size: Optional[int] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
return_dict: bool = True,
):
r"""
The call function to the pipeline for generation.
Args:
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_frames (`int`, *optional*):
The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
num_inference_steps (`int`, *optional*, defaults to 25):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
min_guidance_scale (`float`, *optional*, defaults to 1.0):
The minimum guidance scale. Used for the classifier free guidance with first frame.
max_guidance_scale (`float`, *optional*, defaults to 3.0):
The maximum guidance scale. Used for the classifier free guidance with last frame.
fps (`int`, *optional*, defaults to 7):
Frames per second. The rate at which the generated images shall be exported to a video after generation.
Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
motion_bucket_id (`int`, *optional*, defaults to 127):
The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
noise_aug_strength (`int`, *optional*, defaults to 0.02):
The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
decode_chunk_size (`int`, *optional*):
The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
Examples:
```py
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
pipe.to("cuda")
image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
image = image.resize((1024, 576))
frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width)
# 2. Define call parameters
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = max_guidance_scale > 1.0
# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
# NOTE: Stable Diffusion Video was conditioned on fps - 1, which
# is why it is reduced here.
# See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
fps = fps - 1
# 4. Encode input image using VAE
image = self.image_processor.preprocess(image, height=height, width=width)
noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
image = image + noise_aug_strength * noise
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.vae.to(dtype=torch.float32)
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
image_latents = image_latents.to(image_embeddings.dtype)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
# Repeat the image latents for each frame so we can concatenate them with the noise
# image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
# 5. Get Added Time IDs
added_time_ids = self._get_add_time_ids(
fps,
motion_bucket_id,
noise_aug_strength,
image_embeddings.dtype,
batch_size,
num_videos_per_prompt,
do_classifier_free_guidance,
)
added_time_ids = added_time_ids.to(device)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_frames,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)
# 7. Prepare guidance scale
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
guidance_scale = guidance_scale.to(device, latents.dtype)
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
guidance_scale = _append_dims(guidance_scale, latents.ndim)
self._guidance_scale = guidance_scale
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# Concatenate image_latents over channels dimention
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=image_embeddings,
added_time_ids=added_time_ids,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
else:
frames = latents
self.maybe_free_model_hooks()
if not return_dict:
return frames
return StableVideoDiffusionPipelineOutput(frames=frames)
# resizing utils
# TODO: clean up later
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
h, w = input.shape[-2:]
factors = (h / size[0], w / size[1])
# First, we have to determine sigma
# Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
sigmas = (
max((factors[0] - 1.0) / 2.0, 0.001),
max((factors[1] - 1.0) / 2.0, 0.001),
)
# Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
# https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
# But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
# Make sure it is odd
if (ks[0] % 2) == 0:
ks = ks[0] + 1, ks[1]
if (ks[1] % 2) == 0:
ks = ks[0], ks[1] + 1
input = _gaussian_blur2d(input, ks, sigmas)
output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
return output
def _compute_padding(kernel_size):
"""Compute padding tuple."""
# 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
if len(kernel_size) < 2:
raise AssertionError(kernel_size)
computed = [k - 1 for k in kernel_size]
# for even kernels we need to do asymmetric padding :(
out_padding = 2 * len(kernel_size) * [0]
for i in range(len(kernel_size)):
computed_tmp = computed[-(i + 1)]
pad_front = computed_tmp // 2
pad_rear = computed_tmp - pad_front
out_padding[2 * i + 0] = pad_front
out_padding[2 * i + 1] = pad_rear
return out_padding
def _filter2d(input, kernel):
# prepare kernel
b, c, h, w = input.shape
tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
height, width = tmp_kernel.shape[-2:]
padding_shape: list[int] = _compute_padding([height, width])
input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
# kernel and input tensor reshape to align element-wise or batch-wise params
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
# convolve the tensor with the kernel.
output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
out = output.view(b, c, h, w)
return out
def _gaussian(window_size: int, sigma):
if isinstance(sigma, float):
sigma = torch.tensor([[sigma]])
batch_size = sigma.shape[0]
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
if window_size % 2 == 0:
x = x + 0.5
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
return gauss / gauss.sum(-1, keepdim=True)
def _gaussian_blur2d(input, kernel_size, sigma):
if isinstance(sigma, tuple):
sigma = torch.tensor([sigma], dtype=input.dtype)
else:
sigma = sigma.to(dtype=input.dtype)
ky, kx = int(kernel_size[0]), int(kernel_size[1])
bs = sigma.shape[0]
kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
out_x = _filter2d(input, kernel_x[..., None, :])
out = _filter2d(out_x, kernel_y[..., None])
return out

View File

@@ -323,8 +323,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -358,8 +358,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -358,8 +358,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -357,8 +357,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -144,7 +144,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
interpolation_type: str = "linear",
use_karras_sigmas: Optional[bool] = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace",
timestep_type: str = "discrete", # can be "discrete" or "continuous"
steps_offset: int = 0,
):
if trained_betas is not None:
@@ -164,13 +167,22 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
# TODO: Support the full EDM scalings for all prediction types and timestep types
if timestep_type == "continuous" and prediction_type == "v_prediction":
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
else:
self.timesteps = timesteps
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self.is_scale_input_called = False
self.use_karras_sigmas = use_karras_sigmas
@@ -268,10 +280,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
# TODO: Support the full EDM scalings for all prediction types and timestep types
if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
else:
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
def _sigma_to_t(self, sigma, log_sigmas):
@@ -301,8 +318,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
@@ -412,7 +441,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
elif self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
# denoised = model_output * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(

View File

@@ -303,8 +303,20 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -324,8 +324,20 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -335,8 +335,20 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -337,8 +337,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -32,6 +32,21 @@ class AutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderTiny(metaclass=DummyObject):
_backends = ["torch"]
@@ -272,6 +287,21 @@ class UNetMotionModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class UNetSpatioTemporalConditionModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class VQModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1172,6 +1172,21 @@ class StableUnCLIPPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableVideoDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class TextToVideoSDPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -3,7 +3,7 @@ import random
import struct
import tempfile
from contextlib import contextmanager
from typing import List
from typing import List, Union
import numpy as np
import PIL.Image
@@ -115,7 +115,9 @@ def export_to_obj(mesh, output_obj_path: str = None):
f.writelines("\n".join(combined_data))
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
def export_to_video(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
if is_opencv_available():
import cv2
else:
@@ -123,9 +125,12 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
if isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
h, w, c = video_frames[0].shape
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
for i in range(len(video_frames)):
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
video_writer.write(img)

View File

@@ -0,0 +1,289 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_all_close,
torch_device,
)
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetSpatioTemporalConditionModel
main_input_name = "sample"
@property
def dummy_input(self):
batch_size = 2
num_frames = 2
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"added_time_ids": self._get_add_time_ids(),
}
@property
def input_shape(self):
return (2, 2, 4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
@property
def fps(self):
return 6
@property
def motion_bucket_id(self):
return 127
@property
def noise_aug_strength(self):
return 0.02
@property
def addition_time_embed_dim(self):
return 32
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64),
"down_block_types": (
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
"up_block_types": (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
"cross_attention_dim": 32,
"num_attention_heads": 8,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 2,
"sample_size": 32,
"projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
"addition_time_embed_dim": self.addition_time_embed_dim,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def _get_add_time_ids(self, do_classifier_free_guidance=True):
add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
passed_add_embed_dim = self.addition_time_embed_dim * len(add_time_ids)
expected_add_embed_dim = self.addition_time_embed_dim * 3
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], device=torch_device)
add_time_ids = add_time_ids.repeat(1, 1)
if do_classifier_free_guidance:
add_time_ids = torch.cat([add_time_ids, add_time_ids])
return add_time_ids
@unittest.skip("Number of Norm Groups is not configurable")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("Deprecated functionality")
def test_model_attention_slicing(self):
pass
@unittest.skip("Not supported")
def test_model_with_use_linear_projection(self):
pass
@unittest.skip("Not supported")
def test_model_with_simple_projection(self):
pass
@unittest.skip("Not supported")
def test_model_with_class_embeddings_concat(self):
pass
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_model_with_num_attention_heads_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["num_attention_heads"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_cross_attention_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["cross_attention_dim"] = (32, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_gradient_checkpointing_is_applied(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["num_attention_heads"] = (8, 16)
model_class_copy = copy.copy(self.model_class)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
"TransformerSpatioTemporalModel",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"UNetMidBlockSpatioTemporal",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["num_attention_heads"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
sample = model(**inputs_dict).sample
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4

View File

@@ -23,6 +23,7 @@ from parameterized import parameterized
from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
StableDiffusionPipeline,
@@ -248,11 +249,31 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
)
elif torch_device == "cpu":
expected_output_slice = torch.tensor(
[-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
[
-0.1352,
0.0878,
0.0419,
-0.0818,
-0.1069,
0.0688,
-0.1458,
-0.4446,
-0.0026,
]
)
else:
expected_output_slice = torch.tensor(
[-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
[
-0.2421,
0.4642,
0.2507,
-0.0438,
0.0682,
0.3160,
-0.2018,
-0.0727,
0.2485,
]
)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@@ -364,6 +385,93 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
...
class AutoncoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2
@property
def dummy_input(self):
batch_size = 3
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
num_frames = 3
return {"sample": image, "num_frames": num_frames}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"latent_channels": 4,
"layers_per_block": 2,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_forward_signature(self):
pass
def test_training(self):
pass
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
def tearDown(self):
@@ -609,7 +717,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (27,)])
@require_torch_gpu
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -627,7 +738,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -808,7 +922,10 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -886,7 +1003,10 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
pipe.to(torch_device)
out = pipe(
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
"horse",
num_inference_steps=2,
output_type="pt",
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
@@ -916,7 +1036,8 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
actual_output = sample[0, :2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471],
dtype=torch.float16,
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)
@@ -926,17 +1047,24 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
"openai/consistency-decoder", torch_dtype=torch.float16
) # TODO - update
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
vae=vae,
safety_checker=None,
)
pipe.to(torch_device)
out = pipe(
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
"horse",
num_inference_steps=2,
output_type="pt",
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035],
dtype=torch.float16,
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)

View File

@@ -0,0 +1,523 @@
import gc
import random
import tempfile
import unittest
import numpy as np
import torch
from transformers import (
CLIPImageProcessor,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
import diffusers
from diffusers import (
AutoencoderKLTemporalDecoder,
EulerDiscreteScheduler,
StableVideoDiffusionPipeline,
UNetSpatioTemporalConditionModel,
)
from diffusers.utils import is_accelerate_available, is_accelerate_version, load_image, logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
CaptureLogger,
disable_full_determinism,
enable_full_determinism,
floats_tensor,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableVideoDiffusionPipeline
params = frozenset(["image"])
batch_params = frozenset(["image", "generator"])
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNetSpatioTemporalConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=8,
out_channels=4,
down_block_types=(
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"),
cross_attention_dim=32,
num_attention_heads=8,
projection_class_embeddings_input_dim=96,
addition_time_embed_dim=32,
)
scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
interpolation_type="linear",
num_train_timesteps=1000,
prediction_type="v_prediction",
sigma_max=700.0,
sigma_min=0.002,
steps_offset=1,
timestep_spacing="leading",
timestep_type="continuous",
trained_betas=None,
use_karras_sigmas=True,
)
torch.manual_seed(0)
vae = AutoencoderKLTemporalDecoder(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
config = CLIPVisionConfig(
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
image_size=32,
intermediate_size=37,
patch_size=1,
)
image_encoder = CLIPVisionModelWithProjection(config)
torch.manual_seed(0)
feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
components = {
"unet": unet,
"image_encoder": image_encoder,
"scheduler": scheduler,
"vae": vae,
"feature_extractor": feature_extractor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)).to(device)
inputs = {
"generator": generator,
"image": image,
"num_inference_steps": 2,
"output_type": "pt",
"min_guidance_scale": 1.0,
"max_guidance_scale": 2.5,
"num_frames": 2,
"height": 32,
"width": 32,
}
return inputs
@unittest.skip("Deprecated functionality")
def test_attention_slicing_forward_pass(self):
pass
@unittest.skip("Batched inference works and outputs look correct, but the test is failing")
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for components in pipe.components.values():
if hasattr(components, "set_default_attn_processor"):
components.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = torch.Generator("cpu").manual_seed(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
batched_inputs["generator"] = [torch.Generator("cpu").manual_seed(0) for i in range(batch_size)]
batched_inputs["image"] = torch.cat([inputs["image"]] * batch_size, dim=0)
output = pipe(**inputs).frames
output_batch = pipe(**batched_inputs).frames
assert len(output_batch) == batch_size
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
assert max_diff < expected_max_diff
@unittest.skip("Test is similar to test_inference_batch_single_identical")
def test_inference_batch_consistent(self):
pass
def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
output = pipe(**self.get_dummy_inputs(generator_device)).frames[0]
output_tuple = pipe(**self.get_dummy_inputs(generator_device), return_dict=False)[0]
max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
self.assertLess(max_diff, expected_max_difference)
@unittest.skip("Test is currently failing")
def test_float16_inference(self, expected_max_diff=5e-2):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
components = self.get_dummy_components()
pipe_fp16 = self.pipeline_class(**components)
for component in pipe_fp16.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs).frames[0]
fp16_inputs = self.get_dummy_inputs(torch_device)
output_fp16 = pipe_fp16(**fp16_inputs).frames[0]
max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
if hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs).frames[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for name, component in pipe_loaded.components.items():
if hasattr(component, "dtype"):
self.assertTrue(
component.dtype == torch.float16,
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs).frames[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
if not hasattr(self.pipeline_class, "_optional_components"):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# set all optional components to None
for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output = pipe(**inputs).frames[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for optional_component in pipe._optional_components:
self.assertTrue(
getattr(pipe_loaded, optional_component) is None,
f"`{optional_component}` did not stay set to None after loading.",
)
inputs = self.get_dummy_inputs(generator_device)
output_loaded = pipe_loaded(**inputs).frames[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, expected_max_difference)
def test_save_load_local(self, expected_max_difference=9e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs).frames[0]
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
with CaptureLogger(logger) as cap_logger:
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for name in pipe_loaded.components.keys():
if name not in pipe_loaded._optional_components:
assert name in str(cap_logger)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs).frames[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, expected_max_difference)
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu")).frames[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to("cuda")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cuda" for device in model_devices))
output_cuda = pipe(**self.get_dummy_inputs("cuda")).frames[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(torch_dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
)
def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_offload = pipe(**inputs).frames[0]
pipe.enable_sequential_cpu_offload()
inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs).frames[0]
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
)
def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(generator_device)
output_without_offload = pipe(**inputs).frames[0]
pipe.enable_model_cpu_offload()
inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs).frames[0]
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
offloaded_modules = [
v
for k, v in pipe.components.items()
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
]
(
self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
disable_full_determinism()
expected_max_diff = 9e-4
if not self.test_xformers_attention:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_without_offload = pipe(**inputs).frames[0]
output_without_offload = (
output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
)
pipe.enable_xformers_memory_efficient_attention()
inputs = self.get_dummy_inputs(torch_device)
output_with_offload = pipe(**inputs).frames[0]
output_with_offload = (
output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
)
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")
enable_full_determinism()
@slow
@require_torch_gpu
class StableVideoDiffusionPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_sd_video(self):
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid",
variant="fp16",
torch_dtype=torch.float16,
)
pipe = pipe.to(torch_device)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
)
generator = torch.Generator("cpu").manual_seed(0)
num_frames = 3
output = pipe(
image=image,
num_frames=num_frames,
generator=generator,
num_inference_steps=3,
output_type="np",
)
image = output.frames[0]
assert image.shape == (num_frames, 576, 1024, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.8592, 0.8645, 0.8499, 0.8722, 0.8769, 0.8421, 0.8557, 0.8528, 0.8285])
assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3

View File

@@ -37,6 +37,14 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_timestep_type(self):
timestep_types = ["discrete", "continuous"]
for timestep_type in timestep_types:
self.check_over_configs(timestep_type=timestep_type)
def test_karras_sigmas(self):
self.check_over_configs(use_karras_sigmas=True, sigma_min=0.02, sigma_max=700.0)
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()

View File

@@ -352,8 +352,8 @@ class SchedulerCommonTest(unittest.TestCase):
_ = scheduler.scale_model_input(sample, scaled_sigma_max)
_ = new_scheduler.scale_model_input(sample, scaled_sigma_max)
elif scheduler_class != VQDiffusionScheduler:
_ = scheduler.scale_model_input(sample, 0)
_ = new_scheduler.scale_model_input(sample, 0)
_ = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
_ = new_scheduler.scale_model_input(sample, scheduler.timesteps[-1])
# Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):