mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[core] FreeNoise (#8948)
* initial work draft for freenoise; needs massive cleanup * fix freeinit bug * add animatediff controlnet implementation * revert attention changes * add freenoise * remove old helper functions * add decode batch size param to all pipelines * make style * fix copied from comments * make fix-copies * make style * copy animatediff controlnet implementation from #8972 * add experimental support for num_frames not perfectly fitting context length, ocntext stride * make unet motion model lora work again based on #8995 * copy load video utils from #8972 * copied from AnimateDiff::prepare_latents * address the case where last batch of frames does not match length of indices in prepare latents * decode_batch_size->vae_batch_size; batch vae encode support in animatediff vid2vid * revert sparsectrl and sdxl freenoise changes * revert pia * add freenoise tests * make fix-copies * improve docstrings * add freenoise tests to animatediff controlnet * update tests * Update src/diffusers/models/unets/unet_motion_model.py * add freenoise to animatediff pag * address review comments * make style * update tests * make fix-copies * fix error message * remove copied from comment * fix imports in tests * update --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -272,6 +272,17 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
@@ -782,6 +793,319 @@ class SkipFFTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class FreeNoiseTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A FreeNoise Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (`int`, *optional*):
|
||||
The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, defaults to `False`):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, defaults to `False`):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
ff_inner_dim (`int`, *optional*):
|
||||
Hidden dimension of feed-forward MLP.
|
||||
ff_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in feed-forward MLP.
|
||||
attention_out_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in attention output project layer.
|
||||
context_length (`int`, defaults to `16`):
|
||||
The maximum number of frames that the FreeNoise block processes at once.
|
||||
context_stride (`int`, defaults to `4`):
|
||||
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
|
||||
weighting_scheme (`str`, defaults to `"pyramid"`):
|
||||
The weighting scheme to use for weighting averaging of processed latent frames. As described in the
|
||||
Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
|
||||
used.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float = 0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
context_length: int = 16,
|
||||
context_stride: int = 4,
|
||||
weighting_scheme: str = "pyramid",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
|
||||
frame_indices = []
|
||||
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
|
||||
window_start = i
|
||||
window_end = min(num_frames, i + self.context_length)
|
||||
frame_indices.append((window_start, window_end))
|
||||
return frame_indices
|
||||
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
|
||||
if weighting_scheme == "pyramid":
|
||||
if num_frames % 2 == 0:
|
||||
# num_frames = 4 => [1, 2, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
weights = weights + weights[::-1]
|
||||
else:
|
||||
# num_frames = 5 => [1, 2, 3, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
weights = weights + [num_frames // 2 + 1] + weights[::-1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
|
||||
|
||||
return weights
|
||||
|
||||
def set_free_noise_properties(
|
||||
self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
|
||||
) -> None:
|
||||
self.context_length = context_length
|
||||
self.context_stride = context_stride
|
||||
self.weighting_scheme = weighting_scheme
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
|
||||
# hidden_states: [B x H x W, F, C]
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
num_frames = hidden_states.size(1)
|
||||
frame_indices = self._get_frame_indices(num_frames)
|
||||
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
|
||||
frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
|
||||
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
|
||||
|
||||
# Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
|
||||
# For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
|
||||
# [(0, 16), (4, 20), (8, 24), (10, 26)]
|
||||
if not is_last_frame_batch_complete:
|
||||
if num_frames < self.context_length:
|
||||
raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
|
||||
last_frame_batch_length = num_frames - frame_indices[-1][1]
|
||||
frame_indices.append((num_frames - self.context_length, num_frames))
|
||||
|
||||
num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
|
||||
accumulated_values = torch.zeros_like(hidden_states)
|
||||
|
||||
for i, (frame_start, frame_end) in enumerate(frame_indices):
|
||||
# The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
|
||||
# cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
|
||||
# essentially a non-multiple of `context_length`.
|
||||
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
|
||||
weights *= frame_weights
|
||||
|
||||
hidden_states_chunk = hidden_states[:, frame_start:frame_end]
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states_chunk)
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states_chunk = attn_output + hidden_states_chunk
|
||||
if hidden_states_chunk.ndim == 4:
|
||||
hidden_states_chunk = hidden_states_chunk.squeeze(1)
|
||||
|
||||
# 2. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = self.norm2(hidden_states_chunk)
|
||||
|
||||
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states_chunk = attn_output + hidden_states_chunk
|
||||
|
||||
if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
|
||||
accumulated_values[:, -last_frame_batch_length:] += (
|
||||
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
|
||||
)
|
||||
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
|
||||
else:
|
||||
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
|
||||
num_times_accumulated[:, frame_start:frame_end] += weights
|
||||
|
||||
hidden_states = torch.where(
|
||||
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
|
||||
).to(dtype)
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
@@ -343,6 +343,7 @@ class DownBlockMotion(nn.Module):
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
@@ -536,6 +537,7 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -761,6 +763,7 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -921,9 +924,9 @@ class UpBlockMotion(nn.Module):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
@@ -1923,7 +1926,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
@@ -1953,7 +1955,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
def disable_forward_chunking(self) -> None:
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
|
||||
@@ -42,6 +42,7 @@ from ...utils import (
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
@@ -72,6 +73,7 @@ class AnimateDiffPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
@@ -394,15 +396,20 @@ class AnimateDiffPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
def decode_latents(self, latents, vae_batch_size: int = 16):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], vae_batch_size):
|
||||
batch_latents = latents[i : i + vae_batch_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
video = torch.cat(video)
|
||||
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
@@ -495,10 +502,21 @@ class AnimateDiffPipeline(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
|
||||
if self.free_noise_enabled:
|
||||
latents = self._prepare_latents_free_noise(
|
||||
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
@@ -506,11 +524,6 @@ class AnimateDiffPipeline(
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
@@ -569,6 +582,7 @@ class AnimateDiffPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
vae_batch_size: int = 16,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -637,6 +651,8 @@ class AnimateDiffPipeline(
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
vae_batch_size (`int`, defaults to `16`):
|
||||
The number of frames to decode at a time when calling `decode_latents` method.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -808,7 +824,7 @@ class AnimateDiffPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video_tensor = self.decode_latents(latents, vae_batch_size)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -30,6 +30,7 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..controlnet.multicontrolnet import MultiControlNetModel
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
@@ -109,6 +110,7 @@ class AnimateDiffControlNetPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation with ControlNet guidance.
|
||||
@@ -432,15 +434,16 @@ class AnimateDiffControlNetPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
def decode_latents(self, latents, decode_batch_size: int = 16):
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
||||
def decode_latents(self, latents, vae_batch_size: int = 16):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], decode_batch_size):
|
||||
batch_latents = latents[i : i + decode_batch_size]
|
||||
for i in range(0, latents.shape[0], vae_batch_size):
|
||||
batch_latents = latents[i : i + vae_batch_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
@@ -608,10 +611,22 @@ class AnimateDiffControlNetPipeline(
|
||||
if end > 1.0:
|
||||
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
|
||||
if self.free_noise_enabled:
|
||||
latents = self._prepare_latents_free_noise(
|
||||
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
@@ -619,11 +634,6 @@ class AnimateDiffControlNetPipeline(
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
@@ -718,7 +728,7 @@ class AnimateDiffControlNetPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
decode_batch_size: int = 16,
|
||||
vae_batch_size: int = 16,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -1054,7 +1064,7 @@ class AnimateDiffControlNetPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents, decode_batch_size)
|
||||
video_tensor = self.decode_latents(latents, vae_batch_size)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -35,6 +35,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
@@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for video-to-video generation.
|
||||
@@ -498,15 +500,29 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
def encode_video(self, video, generator, vae_batch_size: int = 16) -> torch.Tensor:
|
||||
latents = []
|
||||
for i in range(0, len(video), vae_batch_size):
|
||||
batch_video = video[i : i + vae_batch_size]
|
||||
batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
|
||||
latents.append(batch_video)
|
||||
return torch.cat(latents)
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
||||
def decode_latents(self, latents, vae_batch_size: int = 16):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], vae_batch_size):
|
||||
batch_latents = latents[i : i + vae_batch_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
video = torch.cat(video)
|
||||
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
@@ -622,6 +638,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
vae_batch_size: int = 16,
|
||||
):
|
||||
if latents is None:
|
||||
num_frames = video.shape[1]
|
||||
@@ -656,13 +673,10 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
|
||||
for i in range(batch_size)
|
||||
self.encode_video(video[i], generator[i], vae_batch_size).unsqueeze(0) for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
|
||||
]
|
||||
init_latents = [self.encode_video(vid, generator, vae_batch_size).unsqueeze(0) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
|
||||
@@ -747,6 +761,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
vae_batch_size: int = 16,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -822,6 +837,8 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
vae_batch_size (`int`, defaults to `16`):
|
||||
The number of frames to decode at a time when calling `decode_latents` method.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -923,6 +940,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
vae_batch_size=vae_batch_size,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
@@ -990,7 +1008,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video_tensor = self.decode_latents(latents, vae_batch_size)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
236
src/diffusers/pipelines/free_noise_utils.py
Normal file
236
src/diffusers/pipelines/free_noise_utils.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
|
||||
from ..models.unets.unet_motion_model import (
|
||||
CrossAttnDownBlockMotion,
|
||||
DownBlockMotion,
|
||||
UpBlockMotion,
|
||||
)
|
||||
from ..utils import logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class AnimateDiffFreeNoiseMixin:
|
||||
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
|
||||
|
||||
def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
|
||||
r"""Helper function to enable FreeNoise in transformer blocks."""
|
||||
|
||||
for motion_module in block.motion_modules:
|
||||
num_transformer_blocks = len(motion_module.transformer_blocks)
|
||||
|
||||
for i in range(num_transformer_blocks):
|
||||
if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
|
||||
motion_module.transformer_blocks[i].set_free_noise_properties(
|
||||
self._free_noise_context_length,
|
||||
self._free_noise_context_stride,
|
||||
self._free_noise_weighting_scheme,
|
||||
)
|
||||
else:
|
||||
assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock)
|
||||
basic_transfomer_block = motion_module.transformer_blocks[i]
|
||||
|
||||
motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock(
|
||||
dim=basic_transfomer_block.dim,
|
||||
num_attention_heads=basic_transfomer_block.num_attention_heads,
|
||||
attention_head_dim=basic_transfomer_block.attention_head_dim,
|
||||
dropout=basic_transfomer_block.dropout,
|
||||
cross_attention_dim=basic_transfomer_block.cross_attention_dim,
|
||||
activation_fn=basic_transfomer_block.activation_fn,
|
||||
attention_bias=basic_transfomer_block.attention_bias,
|
||||
only_cross_attention=basic_transfomer_block.only_cross_attention,
|
||||
double_self_attention=basic_transfomer_block.double_self_attention,
|
||||
positional_embeddings=basic_transfomer_block.positional_embeddings,
|
||||
num_positional_embeddings=basic_transfomer_block.num_positional_embeddings,
|
||||
context_length=self._free_noise_context_length,
|
||||
context_stride=self._free_noise_context_stride,
|
||||
weighting_scheme=self._free_noise_weighting_scheme,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
motion_module.transformer_blocks[i].load_state_dict(
|
||||
basic_transfomer_block.state_dict(), strict=True
|
||||
)
|
||||
|
||||
def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
|
||||
r"""Helper function to disable FreeNoise in transformer blocks."""
|
||||
|
||||
for motion_module in block.motion_modules:
|
||||
num_transformer_blocks = len(motion_module.transformer_blocks)
|
||||
|
||||
for i in range(num_transformer_blocks):
|
||||
if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
|
||||
free_noise_transfomer_block = motion_module.transformer_blocks[i]
|
||||
|
||||
motion_module.transformer_blocks[i] = BasicTransformerBlock(
|
||||
dim=free_noise_transfomer_block.dim,
|
||||
num_attention_heads=free_noise_transfomer_block.num_attention_heads,
|
||||
attention_head_dim=free_noise_transfomer_block.attention_head_dim,
|
||||
dropout=free_noise_transfomer_block.dropout,
|
||||
cross_attention_dim=free_noise_transfomer_block.cross_attention_dim,
|
||||
activation_fn=free_noise_transfomer_block.activation_fn,
|
||||
attention_bias=free_noise_transfomer_block.attention_bias,
|
||||
only_cross_attention=free_noise_transfomer_block.only_cross_attention,
|
||||
double_self_attention=free_noise_transfomer_block.double_self_attention,
|
||||
positional_embeddings=free_noise_transfomer_block.positional_embeddings,
|
||||
num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
motion_module.transformer_blocks[i].load_state_dict(
|
||||
free_noise_transfomer_block.state_dict(), strict=True
|
||||
)
|
||||
|
||||
def _prepare_latents_free_noise(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
context_num_frames = (
|
||||
self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
context_num_frames,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if self._free_noise_noise_type == "random":
|
||||
return latents
|
||||
else:
|
||||
if latents.size(2) == num_frames:
|
||||
return latents
|
||||
elif latents.size(2) != self._free_noise_context_length:
|
||||
raise ValueError(
|
||||
f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}"
|
||||
)
|
||||
latents = latents.to(device)
|
||||
|
||||
if self._free_noise_noise_type == "shuffle_context":
|
||||
for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride):
|
||||
# ensure window is within bounds
|
||||
window_start = max(0, i - self._free_noise_context_length)
|
||||
window_end = min(num_frames, window_start + self._free_noise_context_stride)
|
||||
window_length = window_end - window_start
|
||||
|
||||
if window_length == 0:
|
||||
break
|
||||
|
||||
indices = torch.LongTensor(list(range(window_start, window_end)))
|
||||
shuffled_indices = indices[torch.randperm(window_length, generator=generator)]
|
||||
|
||||
current_start = i
|
||||
current_end = min(num_frames, current_start + window_length)
|
||||
if current_end == current_start + window_length:
|
||||
# batch of frames perfectly fits the window
|
||||
latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]
|
||||
else:
|
||||
# handle the case where the last batch of frames does not fit perfectly with the window
|
||||
prefix_length = current_end - current_start
|
||||
shuffled_indices = shuffled_indices[:prefix_length]
|
||||
latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]
|
||||
|
||||
elif self._free_noise_noise_type == "repeat_context":
|
||||
num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length
|
||||
latents = torch.cat([latents] * num_repeats, dim=2)
|
||||
|
||||
latents = latents[:, :, :num_frames]
|
||||
return latents
|
||||
|
||||
def enable_free_noise(
|
||||
self,
|
||||
context_length: Optional[int] = 16,
|
||||
context_stride: int = 4,
|
||||
weighting_scheme: str = "pyramid",
|
||||
noise_type: str = "shuffle_context",
|
||||
) -> None:
|
||||
r"""
|
||||
Enable long video generation using FreeNoise.
|
||||
|
||||
Args:
|
||||
context_length (`int`, defaults to `16`, *optional*):
|
||||
The number of video frames to process at once. It's recommended to set this to the maximum frames the
|
||||
Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion
|
||||
adapter config is used.
|
||||
context_stride (`int`, *optional*):
|
||||
Long videos are generated by processing many frames. FreeNoise processes these frames in sliding
|
||||
windows of size `context_length`. Context stride allows you to specify how many frames to skip between
|
||||
each window. For example, a context length of 16 and context stride of 4 would process 24 frames as:
|
||||
[0, 15], [4, 19], [8, 23] (0-based indexing)
|
||||
weighting_scheme (`str`, defaults to `pyramid`):
|
||||
Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
|
||||
schemes are supported currently:
|
||||
- "pyramid"
|
||||
Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
|
||||
noise_type (`str`, defaults to "shuffle_context"):
|
||||
TODO
|
||||
"""
|
||||
|
||||
allowed_weighting_scheme = ["pyramid"]
|
||||
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
|
||||
|
||||
if context_length > self.motion_adapter.config.motion_max_seq_length:
|
||||
logger.warning(
|
||||
f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results."
|
||||
)
|
||||
if weighting_scheme not in allowed_weighting_scheme:
|
||||
raise ValueError(
|
||||
f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}"
|
||||
)
|
||||
if noise_type not in allowed_noise_type:
|
||||
raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}")
|
||||
|
||||
self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length
|
||||
self._free_noise_context_stride = context_stride
|
||||
self._free_noise_weighting_scheme = weighting_scheme
|
||||
self._free_noise_noise_type = noise_type
|
||||
|
||||
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
|
||||
for block in blocks:
|
||||
self._enable_free_noise_in_block(block)
|
||||
|
||||
def disable_free_noise(self) -> None:
|
||||
self._free_noise_context_length = None
|
||||
|
||||
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
|
||||
for block in blocks:
|
||||
self._disable_free_noise_in_block(block)
|
||||
|
||||
@property
|
||||
def free_noise_enabled(self):
|
||||
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
|
||||
@@ -35,6 +35,7 @@ from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..animatediff.pipeline_output import AnimateDiffPipelineOutput
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pag_utils import PAGMixin
|
||||
|
||||
@@ -83,6 +84,7 @@ class AnimateDiffPAGPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
PAGMixin,
|
||||
):
|
||||
r"""
|
||||
@@ -404,15 +406,21 @@ class AnimateDiffPAGPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
||||
def decode_latents(self, latents, vae_batch_size: int = 16):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], vae_batch_size):
|
||||
batch_latents = latents[i : i + vae_batch_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
video = torch.cat(video)
|
||||
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
@@ -499,10 +507,22 @@ class AnimateDiffPAGPipeline(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
|
||||
if self.free_noise_enabled:
|
||||
latents = self._prepare_latents_free_noise(
|
||||
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
@@ -510,11 +530,6 @@ class AnimateDiffPAGPipeline(
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
@@ -573,6 +588,7 @@ class AnimateDiffPAGPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
vae_batch_size: int = 16,
|
||||
pag_scale: float = 3.0,
|
||||
pag_adaptive_scale: float = 0.0,
|
||||
):
|
||||
@@ -831,7 +847,7 @@ class AnimateDiffPAGPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video_tensor = self.decode_latents(latents, vae_batch_size)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -17,6 +17,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
UNetMotionModel,
|
||||
)
|
||||
from diffusers.models.attention import FreeNoiseTransformerBlock
|
||||
from diffusers.utils import is_xformers_available, logging
|
||||
from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
|
||||
|
||||
@@ -401,6 +402,64 @@ class AnimateDiffPipelineFastTests(
|
||||
"Enabling of FreeInit should lead to results different from the default pipeline results",
|
||||
)
|
||||
|
||||
def test_free_noise_blocks(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipe.enable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertTrue(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
|
||||
)
|
||||
|
||||
pipe.disable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertFalse(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
|
||||
)
|
||||
|
||||
def test_free_noise(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs_normal = self.get_dummy_inputs(torch_device)
|
||||
frames_normal = pipe(**inputs_normal).frames[0]
|
||||
|
||||
for context_length in [8, 9]:
|
||||
for context_stride in [4, 6]:
|
||||
pipe.enable_free_noise(context_length, context_stride)
|
||||
|
||||
inputs_enable_free_noise = self.get_dummy_inputs(torch_device)
|
||||
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
|
||||
|
||||
pipe.disable_free_noise()
|
||||
|
||||
inputs_disable_free_noise = self.get_dummy_inputs(torch_device)
|
||||
frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
|
||||
|
||||
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
|
||||
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
|
||||
self.assertGreater(
|
||||
sum_enabled,
|
||||
1e1,
|
||||
"Enabling of FreeNoise should lead to results different from the default pipeline results",
|
||||
)
|
||||
self.assertLess(
|
||||
max_diff_disabled,
|
||||
1e-4,
|
||||
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
|
||||
@@ -18,6 +18,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
UNetMotionModel,
|
||||
)
|
||||
from diffusers.models.attention import FreeNoiseTransformerBlock
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
@@ -409,6 +410,64 @@ class AnimateDiffControlNetPipelineFastTests(
|
||||
"Enabling of FreeInit should lead to results different from the default pipeline results",
|
||||
)
|
||||
|
||||
def test_free_noise_blocks(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipe.enable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertTrue(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
|
||||
)
|
||||
|
||||
pipe.disable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertFalse(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
|
||||
)
|
||||
|
||||
def test_free_noise(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
frames_normal = pipe(**inputs_normal).frames[0]
|
||||
|
||||
for context_length in [8, 9]:
|
||||
for context_stride in [4, 6]:
|
||||
pipe.enable_free_noise(context_length, context_stride)
|
||||
|
||||
inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
|
||||
|
||||
pipe.disable_free_noise()
|
||||
|
||||
inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
|
||||
|
||||
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
|
||||
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
|
||||
self.assertGreater(
|
||||
sum_enabled,
|
||||
1e1,
|
||||
"Enabling of FreeNoise should lead to results different from the default pipeline results",
|
||||
)
|
||||
self.assertLess(
|
||||
max_diff_disabled,
|
||||
1e-4,
|
||||
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
|
||||
)
|
||||
|
||||
def test_vae_slicing(self, video_count=2):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -17,6 +17,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
UNetMotionModel,
|
||||
)
|
||||
from diffusers.models.attention import FreeNoiseTransformerBlock
|
||||
from diffusers.utils import is_xformers_available, logging
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
@@ -114,7 +115,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
def get_dummy_inputs(self, device, seed=0, num_frames: int = 2):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
@@ -122,8 +123,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(
|
||||
|
||||
video_height = 32
|
||||
video_width = 32
|
||||
video_num_frames = 2
|
||||
video = [Image.new("RGB", (video_width, video_height))] * video_num_frames
|
||||
video = [Image.new("RGB", (video_width, video_height))] * num_frames
|
||||
|
||||
inputs = {
|
||||
"video": video,
|
||||
@@ -428,3 +428,66 @@ class AnimateDiffVideoToVideoPipelineFastTests(
|
||||
1e1,
|
||||
"Enabling of FreeInit should lead to results different from the default pipeline results",
|
||||
)
|
||||
|
||||
def test_free_noise_blocks(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipe.enable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertTrue(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
|
||||
)
|
||||
|
||||
pipe.disable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertFalse(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
|
||||
)
|
||||
|
||||
def test_free_noise(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
inputs_normal["num_inference_steps"] = 2
|
||||
inputs_normal["strength"] = 0.5
|
||||
frames_normal = pipe(**inputs_normal).frames[0]
|
||||
|
||||
for context_length in [8, 9]:
|
||||
for context_stride in [4, 6]:
|
||||
pipe.enable_free_noise(context_length, context_stride)
|
||||
|
||||
inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
inputs_enable_free_noise["num_inference_steps"] = 2
|
||||
inputs_enable_free_noise["strength"] = 0.5
|
||||
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
|
||||
|
||||
pipe.disable_free_noise()
|
||||
inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
inputs_disable_free_noise["num_inference_steps"] = 2
|
||||
inputs_disable_free_noise["strength"] = 0.5
|
||||
frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
|
||||
|
||||
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
|
||||
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
|
||||
self.assertGreater(
|
||||
sum_enabled,
|
||||
1e1,
|
||||
"Enabling of FreeNoise should lead to results different from the default pipeline results",
|
||||
)
|
||||
self.assertLess(
|
||||
max_diff_disabled,
|
||||
1e-4,
|
||||
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
UNetMotionModel,
|
||||
)
|
||||
from diffusers.models.attention import FreeNoiseTransformerBlock
|
||||
from diffusers.utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
@@ -347,6 +348,64 @@ class AnimateDiffPAGPipelineFastTests(
|
||||
"Enabling of FreeInit should lead to results different from the default pipeline results",
|
||||
)
|
||||
|
||||
def test_free_noise_blocks(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipe.enable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertTrue(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
|
||||
)
|
||||
|
||||
pipe.disable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertFalse(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
|
||||
)
|
||||
|
||||
def test_free_noise(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs_normal = self.get_dummy_inputs(torch_device)
|
||||
frames_normal = pipe(**inputs_normal).frames[0]
|
||||
|
||||
for context_length in [8, 9]:
|
||||
for context_stride in [4, 6]:
|
||||
pipe.enable_free_noise(context_length, context_stride)
|
||||
|
||||
inputs_enable_free_noise = self.get_dummy_inputs(torch_device)
|
||||
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
|
||||
|
||||
pipe.disable_free_noise()
|
||||
|
||||
inputs_disable_free_noise = self.get_dummy_inputs(torch_device)
|
||||
frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
|
||||
|
||||
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
|
||||
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
|
||||
self.assertGreater(
|
||||
sum_enabled,
|
||||
1e1,
|
||||
"Enabling of FreeNoise should lead to results different from the default pipeline results",
|
||||
)
|
||||
self.assertLess(
|
||||
max_diff_disabled,
|
||||
1e-4,
|
||||
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
|
||||
Reference in New Issue
Block a user