
[core] FreeNoise (#8948)

* initial work draft for freenoise; needs massive cleanup

* fix freeinit bug

* add animatediff controlnet implementation

* revert attention changes

* add freenoise

* remove old helper functions

* add decode batch size param to all pipelines

* make style

* fix copied from comments

* make fix-copies

* make style

* copy animatediff controlnet implementation from #8972

* add experimental support for num_frames not perfectly fitting context length, context stride

* make unet motion model lora work again based on #8995

* copy load video utils from #8972

* copied from AnimateDiff::prepare_latents

* address the case where last batch of frames does not match length of indices in prepare latents

* decode_batch_size->vae_batch_size; batch vae encode support in animatediff vid2vid

* revert sparsectrl and sdxl freenoise changes

* revert pia

* add freenoise tests

* make fix-copies

* improve docstrings

* add freenoise tests to animatediff controlnet

* update tests

* Update src/diffusers/models/unets/unet_motion_model.py

* add freenoise to animatediff pag

* address review comments

* make style

* update tests

* make fix-copies

* fix error message

* remove copied from comment

* fix imports in tests

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Aryan
2024-08-07 10:35:18 +05:30
committed by GitHub
parent 2d753b6fb5
commit 16a93f1a25
11 changed files with 911 additions and 50 deletions

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
@@ -272,6 +272,17 @@ class BasicTransformerBlock(nn.Module):
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.
@@ -782,6 +793,319 @@ class SkipFFTransformerBlock(nn.Module):
return hidden_states
@maybe_allow_in_graph
class FreeNoiseTransformerBlock(nn.Module):
r"""
A FreeNoise Transformer block.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
cross_attention_dim (`int`, *optional*):
The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward.
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, defaults to `False`):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, defaults to `False`):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, defaults to `False`):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool`, defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*):
The type of positional embeddings to apply.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
ff_inner_dim (`int`, *optional*):
Hidden dimension of feed-forward MLP.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in feed-forward MLP.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in attention output project layer.
context_length (`int`, defaults to `16`):
The maximum number of frames that the FreeNoise block processes at once.
context_stride (`int`, defaults to `4`):
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
weighting_scheme (`str`, defaults to `"pyramid"`):
The weighting scheme to use for averaging processed latent frames across overlapping windows. As described
in Equation 9 of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
used.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
norm_eps: float = 1e-5,
final_dropout: bool = False,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
context_length: int = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
# 3. Feed-forward
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
frame_indices = []
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
window_start = i
window_end = min(num_frames, i + self.context_length)
frame_indices.append((window_start, window_end))
return frame_indices
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
if weighting_scheme == "pyramid":
if num_frames % 2 == 0:
# num_frames = 4 => [1, 2, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + weights[::-1]
else:
# num_frames = 5 => [1, 2, 3, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + [num_frames // 2 + 1] + weights[::-1]
else:
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
return weights
def set_free_noise_properties(
self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
) -> None:
self.context_length = context_length
self.context_stride = context_stride
self.weighting_scheme = weighting_scheme
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
*args,
**kwargs,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
# hidden_states: [B x H x W, F, C]
device = hidden_states.device
dtype = hidden_states.dtype
num_frames = hidden_states.size(1)
frame_indices = self._get_frame_indices(num_frames)
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
# Handle the out-of-bounds case when the sliding windows do not end exactly at num_frames
# For example, num_frames=25, context_length=16, context_stride=4 produces the ranges:
# [(0, 16), (4, 20), (8, 24), (9, 25)]
if not is_last_frame_batch_complete:
if num_frames < self.context_length:
raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
last_frame_batch_length = num_frames - frame_indices[-1][1]
frame_indices.append((num_frames - self.context_length, num_frames))
num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
accumulated_values = torch.zeros_like(hidden_states)
for i, (frame_start, frame_end) in enumerate(frame_indices):
# The slicing here keeps the weights in sync with the chunk length when (frame_end - frame_start) is
# shorter than `context_length`, e.g. frame_indices=[(0, 16), (16, 20)] when the user provides a video
# with 19 frames, i.e. a frame count that is not a multiple of `context_length`.
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
weights *= frame_weights
hidden_states_chunk = hidden_states[:, frame_start:frame_end]
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states_chunk)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states_chunk = attn_output + hidden_states_chunk
if hidden_states_chunk.ndim == 4:
hidden_states_chunk = hidden_states_chunk.squeeze(1)
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states_chunk)
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states_chunk = attn_output + hidden_states_chunk
if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
accumulated_values[:, -last_frame_batch_length:] += (
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
)
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
else:
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights
hidden_states = torch.where(
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
).to(dtype)
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
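
Below is a minimal, self-contained sketch of the sliding-window bookkeeping that `FreeNoiseTransformerBlock` performs along the frame axis. The helper names are hypothetical; the logic mirrors `_get_frame_indices` and `_get_frame_weights` above, plus the extra trailing window that `forward` appends when the last window does not end at `num_frames`.

from typing import List, Tuple

def get_frame_indices(num_frames: int, context_length: int = 16, context_stride: int = 4) -> List[Tuple[int, int]]:
    # Sliding windows of `context_length` frames, advanced by `context_stride`.
    indices = [
        (i, min(num_frames, i + context_length))
        for i in range(0, num_frames - context_length + 1, context_stride)
    ]
    # As in `forward`, append one more window so trailing frames are still covered.
    if indices[-1][1] != num_frames:
        indices.append((num_frames - context_length, num_frames))
    return indices

def get_pyramid_weights(num_frames: int) -> List[int]:
    # "pyramid" weighting: 4 -> [1, 2, 2, 1], 5 -> [1, 2, 3, 2, 1].
    half = list(range(1, num_frames // 2 + 1))
    middle = [] if num_frames % 2 == 0 else [num_frames // 2 + 1]
    return half + middle + half[::-1]

print(get_frame_indices(25))  # [(0, 16), (4, 20), (8, 24), (9, 25)]
print(get_pyramid_weights(5))  # [1, 2, 3, 2, 1]

Overlapping windows are then blended by accumulating `chunk * weights` and dividing by the accumulated weights, which implements the Equation 9 weighted average from the FreeNoise paper.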

View File

@@ -343,6 +343,7 @@ class DownBlockMotion(nn.Module):
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
output_states = output_states + (hidden_states,)
@@ -536,6 +537,7 @@ class CrossAttnDownBlockMotion(nn.Module):
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -761,6 +763,7 @@ class CrossAttnUpBlockMotion(nn.Module):
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -921,9 +924,9 @@ class UpBlockMotion(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
if self.upsamplers is not None:
@@ -1923,7 +1926,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
@@ -1953,7 +1955,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self) -> None:
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):

View File

@@ -42,6 +42,7 @@ from ...utils import (
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -72,6 +73,7 @@ class AnimateDiffPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for text-to-video generation.
@@ -394,15 +396,20 @@ class AnimateDiffPipeline(
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
def decode_latents(self, latents, vae_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = self.vae.decode(latents).sample
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
video = []
for i in range(0, latents.shape[0], vae_batch_size):
batch_latents = latents[i : i + vae_batch_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)
video = torch.cat(video)
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
@@ -495,10 +502,21 @@ class AnimateDiffPipeline(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
shape = (
batch_size,
num_channels_latents,
@@ -506,11 +524,6 @@ class AnimateDiffPipeline(
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -569,6 +582,7 @@ class AnimateDiffPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_batch_size: int = 16,
**kwargs,
):
r"""
@@ -637,6 +651,8 @@ class AnimateDiffPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
vae_batch_size (`int`, defaults to `16`):
The number of frames to decode at a time when calling the `decode_latents` method.
Examples:
@@ -808,7 +824,7 @@ class AnimateDiffPipeline(
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents)
video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
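
For reference, the new `vae_batch_size` argument simply bounds how many frames pass through the VAE decoder per call. A rough standalone sketch of the pattern, with `vae` standing in for any `AutoencoderKL`-like module rather than the pipeline's actual attribute wiring:

import torch

def decode_in_chunks(vae, latents: torch.Tensor, vae_batch_size: int = 16) -> torch.Tensor:
    # latents: [batch_size * num_frames, channels, height, width], already rescaled by the VAE scaling factor.
    # Decoding a few frames at a time keeps peak VAE memory bounded for long videos.
    frames = []
    for i in range(0, latents.shape[0], vae_batch_size):
        frames.append(vae.decode(latents[i : i + vae_batch_size]).sample)
    return torch.cat(frames)

With the default of 16, a 16-frame video is still decoded in a single VAE call, so short videos behave as before; lower values trade throughput for memory.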

View File

@@ -30,6 +30,7 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..controlnet.multicontrolnet import MultiControlNetModel
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -109,6 +110,7 @@ class AnimateDiffControlNetPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for text-to-video generation with ControlNet guidance.
@@ -432,15 +434,16 @@ class AnimateDiffControlNetPipeline(
return ip_adapter_image_embeds
def decode_latents(self, latents, decode_batch_size: int = 16):
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
def decode_latents(self, latents, vae_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
video = []
for i in range(0, latents.shape[0], decode_batch_size):
batch_latents = latents[i : i + decode_batch_size]
for i in range(0, latents.shape[0], vae_batch_size):
batch_latents = latents[i : i + vae_batch_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)
@@ -608,10 +611,22 @@ class AnimateDiffControlNetPipeline(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
shape = (
batch_size,
num_channels_latents,
@@ -619,11 +634,6 @@ class AnimateDiffControlNetPipeline(
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -718,7 +728,7 @@ class AnimateDiffControlNetPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
decode_batch_size: int = 16,
vae_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.
@@ -1054,7 +1064,7 @@ class AnimateDiffControlNetPipeline(
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents, decode_batch_size)
video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models

View File

@@ -35,6 +35,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for video-to-video generation.
@@ -498,15 +500,29 @@ class AnimateDiffVideoToVideoPipeline(
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
def encode_video(self, video, generator, vae_batch_size: int = 16) -> torch.Tensor:
latents = []
for i in range(0, len(video), vae_batch_size):
batch_video = video[i : i + vae_batch_size]
batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
latents.append(batch_video)
return torch.cat(latents)
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
def decode_latents(self, latents, vae_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = self.vae.decode(latents).sample
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
video = []
for i in range(0, latents.shape[0], vae_batch_size):
batch_latents = latents[i : i + vae_batch_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)
video = torch.cat(video)
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
@@ -622,6 +638,7 @@ class AnimateDiffVideoToVideoPipeline(
device,
generator,
latents=None,
vae_batch_size: int = 16,
):
if latents is None:
num_frames = video.shape[1]
@@ -656,13 +673,10 @@ class AnimateDiffVideoToVideoPipeline(
)
init_latents = [
retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
for i in range(batch_size)
self.encode_video(video[i], generator[i], vae_batch_size).unsqueeze(0) for i in range(batch_size)
]
else:
init_latents = [
retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
]
init_latents = [self.encode_video(vid, generator, vae_batch_size).unsqueeze(0) for vid in video]
init_latents = torch.cat(init_latents, dim=0)
@@ -747,6 +761,7 @@ class AnimateDiffVideoToVideoPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.
@@ -822,6 +837,8 @@ class AnimateDiffVideoToVideoPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
vae_batch_size (`int`, defaults to `16`):
The number of frames to decode at a time when calling the `decode_latents` method.
Examples:
@@ -923,6 +940,7 @@ class AnimateDiffVideoToVideoPipeline(
device=device,
generator=generator,
latents=latents,
vae_batch_size=vae_batch_size,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -990,7 +1008,7 @@ class AnimateDiffVideoToVideoPipeline(
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents)
video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
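
The new `encode_video` helper applies the same chunking idea on the encode side. A hedged sketch of the core loop, sampling from the posterior via the public `AutoencoderKL` API instead of the pipeline's `retrieve_latents` helper:

import torch

def encode_in_chunks(vae, video: torch.Tensor, generator=None, vae_batch_size: int = 16) -> torch.Tensor:
    # video: [num_frames, channels, height, width], preprocessed to the VAE's expected value range.
    latents = []
    for i in range(0, video.shape[0], vae_batch_size):
        posterior = vae.encode(video[i : i + vae_batch_size]).latent_dist
        latents.append(posterior.sample(generator=generator))
    return torch.cat(latents)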

View File

@@ -0,0 +1,236 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
from ..models.unets.unet_motion_model import (
CrossAttnDownBlockMotion,
DownBlockMotion,
UpBlockMotion,
)
from ..utils import logging
from ..utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class AnimateDiffFreeNoiseMixin:
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to enable FreeNoise in transformer blocks."""
for motion_module in block.motion_modules:
num_transformer_blocks = len(motion_module.transformer_blocks)
for i in range(num_transformer_blocks):
if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
motion_module.transformer_blocks[i].set_free_noise_properties(
self._free_noise_context_length,
self._free_noise_context_stride,
self._free_noise_weighting_scheme,
)
else:
assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock)
basic_transfomer_block = motion_module.transformer_blocks[i]
motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock(
dim=basic_transfomer_block.dim,
num_attention_heads=basic_transfomer_block.num_attention_heads,
attention_head_dim=basic_transfomer_block.attention_head_dim,
dropout=basic_transfomer_block.dropout,
cross_attention_dim=basic_transfomer_block.cross_attention_dim,
activation_fn=basic_transfomer_block.activation_fn,
attention_bias=basic_transfomer_block.attention_bias,
only_cross_attention=basic_transfomer_block.only_cross_attention,
double_self_attention=basic_transfomer_block.double_self_attention,
positional_embeddings=basic_transfomer_block.positional_embeddings,
num_positional_embeddings=basic_transfomer_block.num_positional_embeddings,
context_length=self._free_noise_context_length,
context_stride=self._free_noise_context_stride,
weighting_scheme=self._free_noise_weighting_scheme,
).to(device=self.device, dtype=self.dtype)
motion_module.transformer_blocks[i].load_state_dict(
basic_transfomer_block.state_dict(), strict=True
)
def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to disable FreeNoise in transformer blocks."""
for motion_module in block.motion_modules:
num_transformer_blocks = len(motion_module.transformer_blocks)
for i in range(num_transformer_blocks):
if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
free_noise_transfomer_block = motion_module.transformer_blocks[i]
motion_module.transformer_blocks[i] = BasicTransformerBlock(
dim=free_noise_transfomer_block.dim,
num_attention_heads=free_noise_transfomer_block.num_attention_heads,
attention_head_dim=free_noise_transfomer_block.attention_head_dim,
dropout=free_noise_transfomer_block.dropout,
cross_attention_dim=free_noise_transfomer_block.cross_attention_dim,
activation_fn=free_noise_transfomer_block.activation_fn,
attention_bias=free_noise_transfomer_block.attention_bias,
only_cross_attention=free_noise_transfomer_block.only_cross_attention,
double_self_attention=free_noise_transfomer_block.double_self_attention,
positional_embeddings=free_noise_transfomer_block.positional_embeddings,
num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings,
).to(device=self.device, dtype=self.dtype)
motion_module.transformer_blocks[i].load_state_dict(
free_noise_transfomer_block.state_dict(), strict=True
)
def _prepare_latents_free_noise(
self,
batch_size: int,
num_channels_latents: int,
num_frames: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
context_num_frames = (
self._free_noise_context_length if self._free_noise_noise_type == "repeat_context" else num_frames
)
shape = (
batch_size,
num_channels_latents,
context_num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if self._free_noise_noise_type == "random":
return latents
else:
if latents.size(2) == num_frames:
return latents
elif latents.size(2) != self._free_noise_context_length:
raise ValueError(
f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}"
)
latents = latents.to(device)
if self._free_noise_noise_type == "shuffle_context":
for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride):
# ensure window is within bounds
window_start = max(0, i - self._free_noise_context_length)
window_end = min(num_frames, window_start + self._free_noise_context_stride)
window_length = window_end - window_start
if window_length == 0:
break
indices = torch.LongTensor(list(range(window_start, window_end)))
shuffled_indices = indices[torch.randperm(window_length, generator=generator)]
current_start = i
current_end = min(num_frames, current_start + window_length)
if current_end == current_start + window_length:
# batch of frames perfectly fits the window
latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]
else:
# handle the case where the last batch of frames does not fit perfectly with the window
prefix_length = current_end - current_start
shuffled_indices = shuffled_indices[:prefix_length]
latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]
elif self._free_noise_noise_type == "repeat_context":
num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length
latents = torch.cat([latents] * num_repeats, dim=2)
latents = latents[:, :, :num_frames]
return latents
def enable_free_noise(
self,
context_length: Optional[int] = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
noise_type: str = "shuffle_context",
) -> None:
r"""
Enable long video generation using FreeNoise.
Args:
context_length (`int`, defaults to `16`, *optional*):
The number of video frames to process at once. It's recommended to set this to the maximum frames the
Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion
adapter config is used.
context_stride (`int`, *optional*):
Long videos are generated by processing many frames. FreeNoise processes these frames in sliding
windows of size `context_length`. Context stride allows you to specify how many frames to skip between
each window. For example, a context length of 16 and context stride of 4 would process 24 frames as:
[0, 15], [4, 19], [8, 23] (0-based indexing)
weighting_scheme (`str`, defaults to `pyramid`):
Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
schemes are supported currently:
- "pyramid"
Performs weighted averaging with a pyramid-like weight pattern: [1, 2, 3, 2, 1].
noise_type (`str`, defaults to "shuffle_context"):
Must be one of ["shuffle_context", "repeat_context", "random"]. "shuffle_context" re-uses the first
`context_length` latent frames, shuffled, to fill later frame windows; "repeat_context" repeats the
first `context_length` latent frames to cover `num_frames`; "random" samples independent noise for
every frame.
"""
allowed_weighting_scheme = ["pyramid"]
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
if context_length > self.motion_adapter.config.motion_max_seq_length:
logger.warning(
f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results."
)
if weighting_scheme not in allowed_weighting_scheme:
raise ValueError(
f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}"
)
if noise_type not in allowed_noise_type:
raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}")
self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length
self._free_noise_context_stride = context_stride
self._free_noise_weighting_scheme = weighting_scheme
self._free_noise_noise_type = noise_type
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
self._enable_free_noise_in_block(block)
def disable_free_noise(self) -> None:
self._free_noise_context_length = None
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
self._disable_free_noise_in_block(block)
@property
def free_noise_enabled(self):
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
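
Taken together, the mixin exposes a small user-facing API. A hedged end-to-end usage sketch follows; the checkpoint ids, scheduler settings, and prompt are illustrative assumptions and not part of this commit:

import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

# Checkpoint ids below are examples; any AnimateDiff-compatible base model and motion adapter should work.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
pipe.to("cuda")

# Process frames in sliding windows of 16 with a stride of 4, blending overlaps with "pyramid" weights.
pipe.enable_free_noise(context_length=16, context_stride=4)

# Generate a video longer than the motion adapter's training length and decode the latents
# through the VAE 16 frames at a time via the new `vae_batch_size` argument.
frames = pipe(
    "a panda surfing on a wave, high quality",
    num_frames=64,
    num_inference_steps=25,
    guidance_scale=7.5,
    vae_batch_size=16,
).frames[0]
export_to_gif(frames, "freenoise.gif")

pipe.disable_free_noise()  # restores plain AnimateDiff behavior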

View File

@@ -35,6 +35,7 @@ from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..animatediff.pipeline_output import AnimateDiffPipelineOutput
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pag_utils import PAGMixin
@@ -83,6 +84,7 @@ class AnimateDiffPAGPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
PAGMixin,
):
r"""
@@ -404,15 +406,21 @@ class AnimateDiffPAGPipeline(
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
def decode_latents(self, latents, vae_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = self.vae.decode(latents).sample
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
video = []
for i in range(0, latents.shape[0], vae_batch_size):
batch_latents = latents[i : i + vae_batch_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)
video = torch.cat(video)
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
@@ -499,10 +507,22 @@ class AnimateDiffPAGPipeline(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
shape = (
batch_size,
num_channels_latents,
@@ -510,11 +530,6 @@ class AnimateDiffPAGPipeline(
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -573,6 +588,7 @@ class AnimateDiffPAGPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_batch_size: int = 16,
pag_scale: float = 3.0,
pag_adaptive_scale: float = 0.0,
):
@@ -831,7 +847,7 @@ class AnimateDiffPAGPipeline(
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents)
video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models

View File

@@ -17,6 +17,7 @@ from diffusers import (
UNet2DConditionModel,
UNetMotionModel,
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
@@ -401,6 +402,64 @@ class AnimateDiffPipelineFastTests(
"Enabling of FreeInit should lead to results different from the default pipeline results",
)
def test_free_noise_blocks(self):
components = self.get_dummy_components()
pipe: AnimateDiffPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_noise()
for block in pipe.unet.down_blocks:
for motion_module in block.motion_modules:
for transformer_block in motion_module.transformer_blocks:
self.assertTrue(
isinstance(transformer_block, FreeNoiseTransformerBlock),
"Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
)
pipe.disable_free_noise()
for block in pipe.unet.down_blocks:
for motion_module in block.motion_modules:
for transformer_block in motion_module.transformer_blocks:
self.assertFalse(
isinstance(transformer_block, FreeNoiseTransformerBlock),
"Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
)
def test_free_noise(self):
components = self.get_dummy_components()
pipe: AnimateDiffPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]
for context_length in [8, 9]:
for context_stride in [4, 6]:
pipe.enable_free_noise(context_length, context_stride)
inputs_enable_free_noise = self.get_dummy_inputs(torch_device)
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
pipe.disable_free_noise()
inputs_disable_free_noise = self.get_dummy_inputs(torch_device)
frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
self.assertGreater(
sum_enabled,
1e1,
"Enabling of FreeNoise should lead to results different from the default pipeline results",
)
self.assertLess(
max_diff_disabled,
1e-4,
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",

View File

@@ -18,6 +18,7 @@ from diffusers import (
UNet2DConditionModel,
UNetMotionModel,
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import logging
from diffusers.utils.testing_utils import torch_device
@@ -409,6 +410,64 @@ class AnimateDiffControlNetPipelineFastTests(
"Enabling of FreeInit should lead to results different from the default pipeline results",
)
def test_free_noise_blocks(self):
components = self.get_dummy_components()
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_noise()
for block in pipe.unet.down_blocks:
for motion_module in block.motion_modules:
for transformer_block in motion_module.transformer_blocks:
self.assertTrue(
isinstance(transformer_block, FreeNoiseTransformerBlock),
"Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
)
pipe.disable_free_noise()
for block in pipe.unet.down_blocks:
for motion_module in block.motion_modules:
for transformer_block in motion_module.transformer_blocks:
self.assertFalse(
isinstance(transformer_block, FreeNoiseTransformerBlock),
"Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
)
def test_free_noise(self):
components = self.get_dummy_components()
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16)
frames_normal = pipe(**inputs_normal).frames[0]
for context_length in [8, 9]:
for context_stride in [4, 6]:
pipe.enable_free_noise(context_length, context_stride)
inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
pipe.disable_free_noise()
inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
self.assertGreater(
sum_enabled,
1e1,
"Enabling of FreeNoise should lead to results different from the default pipeline results",
)
self.assertLess(
max_diff_disabled,
1e-4,
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
)
def test_vae_slicing(self, video_count=2):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()

View File

@@ -17,6 +17,7 @@ from diffusers import (
UNet2DConditionModel,
UNetMotionModel,
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
from diffusers.utils.testing_utils import torch_device
@@ -114,7 +115,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(
}
return components
def get_dummy_inputs(self, device, seed=0):
def get_dummy_inputs(self, device, seed=0, num_frames: int = 2):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
@@ -122,8 +123,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(
video_height = 32
video_width = 32
video_num_frames = 2
video = [Image.new("RGB", (video_width, video_height))] * video_num_frames
video = [Image.new("RGB", (video_width, video_height))] * num_frames
inputs = {
"video": video,
@@ -428,3 +428,66 @@ class AnimateDiffVideoToVideoPipelineFastTests(
1e1,
"Enabling of FreeInit should lead to results different from the default pipeline results",
)
def test_free_noise_blocks(self):
components = self.get_dummy_components()
pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_noise()
for block in pipe.unet.down_blocks:
for motion_module in block.motion_modules:
for transformer_block in motion_module.transformer_blocks:
self.assertTrue(
isinstance(transformer_block, FreeNoiseTransformerBlock),
"Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
)
pipe.disable_free_noise()
for block in pipe.unet.down_blocks:
for motion_module in block.motion_modules:
for transformer_block in motion_module.transformer_blocks:
self.assertFalse(
isinstance(transformer_block, FreeNoiseTransformerBlock),
"Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
)
def test_free_noise(self):
components = self.get_dummy_components()
pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16)
inputs_normal["num_inference_steps"] = 2
inputs_normal["strength"] = 0.5
frames_normal = pipe(**inputs_normal).frames[0]
for context_length in [8, 9]:
for context_stride in [4, 6]:
pipe.enable_free_noise(context_length, context_stride)
inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
inputs_enable_free_noise["num_inference_steps"] = 2
inputs_enable_free_noise["strength"] = 0.5
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
pipe.disable_free_noise()
inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
inputs_disable_free_noise["num_inference_steps"] = 2
inputs_disable_free_noise["strength"] = 0.5
frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
self.assertGreater(
sum_enabled,
1e1,
"Enabling of FreeNoise should lead to results different from the default pipeline results",
)
self.assertLess(
max_diff_disabled,
1e-4,
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
)

View File

@@ -17,6 +17,7 @@ from diffusers import (
UNet2DConditionModel,
UNetMotionModel,
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import torch_device
@@ -347,6 +348,64 @@ class AnimateDiffPAGPipelineFastTests(
"Enabling of FreeInit should lead to results different from the default pipeline results",
)
def test_free_noise_blocks(self):
components = self.get_dummy_components()
pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_noise()
for block in pipe.unet.down_blocks:
for motion_module in block.motion_modules:
for transformer_block in motion_module.transformer_blocks:
self.assertTrue(
isinstance(transformer_block, FreeNoiseTransformerBlock),
"Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
)
pipe.disable_free_noise()
for block in pipe.unet.down_blocks:
for motion_module in block.motion_modules:
for transformer_block in motion_module.transformer_blocks:
self.assertFalse(
isinstance(transformer_block, FreeNoiseTransformerBlock),
"Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
)
def test_free_noise(self):
components = self.get_dummy_components()
pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]
for context_length in [8, 9]:
for context_stride in [4, 6]:
pipe.enable_free_noise(context_length, context_stride)
inputs_enable_free_noise = self.get_dummy_inputs(torch_device)
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
pipe.disable_free_noise()
inputs_disable_free_noise = self.get_dummy_inputs(torch_device)
frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
self.assertGreater(
sum_enabled,
1e1,
"Enabling of FreeNoise should lead to results different from the default pipeline results",
)
self.assertLess(
max_diff_disabled,
1e-4,
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",