mirror of https://github.com/huggingface/diffusers.git
[Refactor] FreeInit for AnimateDiff based pipelines (#6874)
* update * update * update * update * update * update * update * update * update * update
@@ -13,12 +13,10 @@
 # limitations under the License.

 import inspect
-import math
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
 import torch
-import torch.fft as fft
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor
@@ -43,6 +41,7 @@ from ...utils import (
     unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
+from ..free_init_utils import FreeInitMixin
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import AnimateDiffPipelineOutput

@@ -87,72 +86,9 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
     return outputs


-def _get_freeinit_freq_filter(
-    shape: Tuple[int, ...],
-    device: Union[str, torch.dtype],
-    filter_type: str,
-    order: float,
-    spatial_stop_frequency: float,
-    temporal_stop_frequency: float,
-) -> torch.Tensor:
-    r"""Returns the FreeInit filter based on filter type and other input conditions."""
-
-    T, H, W = shape[-3], shape[-2], shape[-1]
-    mask = torch.zeros(shape)
-
-    if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
-        return mask
-
-    if filter_type == "butterworth":
-
-        def retrieve_mask(x):
-            return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
-
-    elif filter_type == "gaussian":
-
-        def retrieve_mask(x):
-            return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
-
-    elif filter_type == "ideal":
-
-        def retrieve_mask(x):
-            return 1 if x <= spatial_stop_frequency * 2 else 0
-
-    else:
-        raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
-
-    for t in range(T):
-        for h in range(H):
-            for w in range(W):
-                d_square = (
-                    ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2
-                    + (2 * h / H - 1) ** 2
-                    + (2 * w / W - 1) ** 2
-                )
-                mask[..., t, h, w] = retrieve_mask(d_square)
-
-    return mask.to(device)
-
-
-def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
-    r"""Noise reinitialization."""
-    # FFT
-    x_freq = fft.fftn(x, dim=(-3, -2, -1))
-    x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
-    noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
-    noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
-
-    # frequency mix
-    HPF = 1 - LPF
-    x_freq_low = x_freq * LPF
-    noise_freq_high = noise_freq * HPF
-    x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain
-
-    # IFFT
-    x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
-    x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
-
-    return x_mixed
-
-
-class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
+class AnimateDiffPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin
+):
     r"""
     Pipeline for text-to-video generation.
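Spelled out, the three `retrieve_mask` branches above evaluate the following low-pass responses on the normalized squared distance $d^2$ (the temporal axis is rescaled by $d_s/d_t$ before squaring; $d_s$ is `spatial_stop_frequency`, $d_t$ is `temporal_stop_frequency`, $n$ is `order`):

$$
H_{\mathrm{butterworth}}(d^2) = \frac{1}{1 + (d^2/d_s^2)^{n}}, \qquad
H_{\mathrm{gaussian}}(d^2) = \exp\!\left(-\frac{d^2}{2 d_s^2}\right), \qquad
H_{\mathrm{ideal}}(d^2) = \mathbf{1}\!\left[\,d^2 \le 2\,d_s\,\right]
$$

This transcribes the code as written; note that the ideal branch compares $d^2$ against $2 d_s$ rather than $d_s^2$.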
@@ -182,7 +118,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
     """

     model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
-    _optional_components = ["feature_extractor", "image_encoder"]
+    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
@@ -204,7 +140,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
-        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+        if isinstance(unet, UNet2DConditionModel):
+            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

         self.register_modules(
             vae=vae,
@@ -530,63 +467,10 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             raise ValueError("The pipeline must have `unet` for using FreeU.")
         self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
     def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

-    @property
-    def free_init_enabled(self):
-        return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None
-
-    def enable_free_init(
-        self,
-        num_iters: int = 3,
-        use_fast_sampling: bool = False,
-        method: str = "butterworth",
-        order: int = 4,
-        spatial_stop_frequency: float = 0.25,
-        temporal_stop_frequency: float = 0.25,
-        generator: torch.Generator = None,
-    ):
-        """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
-
-        This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
-
-        Args:
-            num_iters (`int`, *optional*, defaults to `3`):
-                Number of FreeInit noise re-initialization iterations.
-            use_fast_sampling (`bool`, *optional*, defaults to `False`):
-                Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
-                the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
-            method (`str`, *optional*, defaults to `butterworth`):
-                Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
-                FreeInit low pass filter.
-            order (`int`, *optional*, defaults to `4`):
-                Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
-                whereas lower values lead to `gaussian` method behaviour.
-            spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
-                Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
-                the original implementation.
-            temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
-                Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
-                the original implementation.
-            generator (`torch.Generator`, *optional*, defaults to `0.25`):
-                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
-                FreeInit generation deterministic.
-        """
-        self._free_init_num_iters = num_iters
-        self._free_init_use_fast_sampling = use_fast_sampling
-        self._free_init_method = method
-        self._free_init_order = order
-        self._free_init_spatial_stop_frequency = spatial_stop_frequency
-        self._free_init_temporal_stop_frequency = temporal_stop_frequency
-        self._free_init_generator = generator
-
-    def disable_free_init(self):
-        """Disables the FreeInit mechanism if enabled."""
-        self._free_init_num_iters = None
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
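One behavioural consequence of the removal above: the FreeInit generator is no longer stored on the pipeline. The mixin's `_apply_free_init` (added below in `free_init_utils.py`) instead takes the `generator` that was passed to the pipeline call itself. A minimal sketch of the before/after calling convention (prompt and seed are illustrative):

```python
import torch

# before this PR: the generator was a parameter of enable_free_init
# pipe.enable_free_init(num_iters=3, generator=torch.Generator().manual_seed(0))

# after this PR: enable_free_init no longer accepts a generator;
# FreeInit reuses the generator supplied to the pipeline call
pipe.enable_free_init(num_iters=3, method="butterworth", order=4)
video = pipe(
    prompt="a panda playing a guitar",  # illustrative prompt
    generator=torch.Generator().manual_seed(0),
).frames[0]
```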
@@ -691,158 +575,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         latents = latents * self.scheduler.init_noise_sigma
         return latents

-    def _denoise_loop(
-        self,
-        timesteps,
-        num_inference_steps,
-        do_classifier_free_guidance,
-        guidance_scale,
-        num_warmup_steps,
-        prompt_embeds,
-        negative_prompt_embeds,
-        latents,
-        cross_attention_kwargs,
-        added_cond_kwargs,
-        extra_step_kwargs,
-        callback,
-        callback_steps,
-        callback_on_step_end,
-        callback_on_step_end_tensor_inputs,
-    ):
-        """Denoising loop for AnimateDiff."""
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    added_cond_kwargs=added_cond_kwargs,
-                ).sample
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
-
-        return latents
-
-    def _free_init_loop(
-        self,
-        height,
-        width,
-        num_frames,
-        num_channels_latents,
-        batch_size,
-        num_videos_per_prompt,
-        denoise_args,
-        device,
-    ):
-        """Denoising loop for AnimateDiff using FreeInit noise reinitialization technique."""
-
-        latents = denoise_args.get("latents")
-        prompt_embeds = denoise_args.get("prompt_embeds")
-        timesteps = denoise_args.get("timesteps")
-        num_inference_steps = denoise_args.get("num_inference_steps")
-
-        latent_shape = (
-            batch_size * num_videos_per_prompt,
-            num_channels_latents,
-            num_frames,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
-        free_init_filter_shape = (
-            1,
-            num_channels_latents,
-            num_frames,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
-        free_init_freq_filter = _get_freeinit_freq_filter(
-            shape=free_init_filter_shape,
-            device=device,
-            filter_type=self._free_init_method,
-            order=self._free_init_order,
-            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
-            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
-        )
-
-        with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar:
-            for i in range(self._free_init_num_iters):
-                # For the first FreeInit iteration, the original latent is used without modification.
-                # Subsequent iterations apply the noise reinitialization technique.
-                if i == 0:
-                    initial_noise = latents.detach().clone()
-                else:
-                    current_diffuse_timestep = (
-                        self.scheduler.config.num_train_timesteps - 1
-                    )  # diffuse to t=999 noise level
-                    diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long()
-                    z_T = self.scheduler.add_noise(
-                        original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device)
-                    ).to(dtype=torch.float32)
-                    z_rand = randn_tensor(
-                        shape=latent_shape,
-                        generator=self._free_init_generator,
-                        device=device,
-                        dtype=torch.float32,
-                    )
-                    latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter)
-                    latents = latents.to(prompt_embeds.dtype)
-
-                # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
-                if self._free_init_use_fast_sampling:
-                    current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1))
-                    self.scheduler.set_timesteps(current_num_inference_steps, device=device)
-                    timesteps = self.scheduler.timesteps
-                    denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps})
-
-                num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-                denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps})
-                latents = self._denoise_loop(**denoise_args)
-
-                free_init_progress_bar.update()
-
-        return latents
-
-    def _retrieve_video_frames(self, latents, output_type, return_dict):
-        """Helper function to handle latents to output conversion."""
-        if output_type == "latent":
-            return AnimateDiffPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
-
-        if not return_dict:
-            return (video,)
-
-        return AnimateDiffPipelineOutput(frames=video)
-
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -1046,7 +778,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
-        self._num_timesteps = len(timesteps)

         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -1068,43 +799,64 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         # 7. Add image embeds for IP-Adapter
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

-        # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        denoise_args = {
-            "timesteps": timesteps,
-            "num_inference_steps": num_inference_steps,
-            "do_classifier_free_guidance": self.do_classifier_free_guidance,
-            "guidance_scale": guidance_scale,
-            "num_warmup_steps": num_warmup_steps,
-            "prompt_embeds": prompt_embeds,
-            "negative_prompt_embeds": negative_prompt_embeds,
-            "latents": latents,
-            "cross_attention_kwargs": self.cross_attention_kwargs,
-            "added_cond_kwargs": added_cond_kwargs,
-            "extra_step_kwargs": extra_step_kwargs,
-            "callback": callback,
-            "callback_steps": callback_steps,
-            "callback_on_step_end": callback_on_step_end,
-            "callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs,
-        }
-
-        if self.free_init_enabled:
-            latents = self._free_init_loop(
-                height=height,
-                width=width,
-                num_frames=num_frames,
-                num_channels_latents=num_channels_latents,
-                batch_size=batch_size,
-                num_videos_per_prompt=num_videos_per_prompt,
-                denoise_args=denoise_args,
-                device=device,
-            )
-        else:
-            latents = self._denoise_loop(**denoise_args)
-
-        video = self._retrieve_video_frames(latents, output_type, return_dict)
+        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+        for free_init_iter in range(num_free_init_iters):
+            if self.free_init_enabled:
+                latents, timesteps = self._apply_free_init(
+                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+                )
+
+            self._num_timesteps = len(timesteps)
+            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                    ).sample
+
+                    # perform guidance
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if callback is not None and i % callback_steps == 0:
+                            callback(i, t, latents)
+
+        if output_type == "latent":
+            return AnimateDiffPipelineOutput(frames=latents)
+
+        video_tensor = self.decode_latents(latents)
+        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

         # 9. Offload all models
         self.maybe_free_model_hooks()

-        return video
+        if not return_dict:
+            return (video,)
+
+        return AnimateDiffPipelineOutput(frames=video)
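Putting the pieces together, a minimal text-to-video run with FreeInit enabled might look as follows. This is a sketch rather than part of the diff; the checkpoint IDs and scheduler settings follow the usual AnimateDiff examples and are assumptions here:

```python
import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

# assumed checkpoints, per the common AnimateDiff examples
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear", clip_sample=False)

# FreeInit: 3 noise re-initialization iterations, full-quality sampling
pipe.enable_free_init(num_iters=3, use_fast_sampling=False)

output = pipe(
    prompt="a panda surfing a wave, high quality",  # illustrative prompt
    num_frames=16,
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(42),
)
pipe.disable_free_init()
export_to_gif(output.frames[0], "animation.gif")
```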
@@ -34,6 +34,7 @@ from ...schedulers import (
 )
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
+from ..free_init_utils import FreeInitMixin
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import AnimateDiffPipelineOutput

@@ -163,7 +164,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
+class AnimateDiffVideoToVideoPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin
+):
     r"""
     Pipeline for video-to-video generation.
@@ -193,7 +196,7 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
     """

     model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
-    _optional_components = ["feature_extractor", "image_encoder"]
+    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
@@ -215,7 +218,8 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
         image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
-        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+        if isinstance(unet, UNet2DConditionModel):
+            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

         self.register_modules(
             vae=vae,
@@ -584,12 +588,12 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
         if video is not None and latents is not None:
             raise ValueError("Only one of `video` or `latents` should be provided")

-    def get_timesteps(self, num_inference_steps, strength, device):
+    def get_timesteps(self, num_inference_steps, timesteps, strength, device):
         # get the original timestep using init_timestep
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

         t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        timesteps = timesteps[t_start * self.scheduler.order :]

         return timesteps, num_inference_steps - t_start
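The truncation itself is unchanged; only the source of `timesteps` is, so the method can now slice a freshly reset schedule instead of reading `self.scheduler.timesteps`. Concretely, with illustrative values:

```python
num_inference_steps, strength = 50, 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)                          # 20
# timesteps = timesteps[t_start * order:] keeps the last 30 of 50 steps
# (for a first-order scheduler), and 50 - 20 = 30 steps are returned
```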
@@ -876,9 +880,8 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
-        self._num_timesteps = len(timesteps)

         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -901,42 +904,55 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
         # 7. Add image embeds for IP-Adapter
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

-        # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=self.cross_attention_kwargs,
-                    added_cond_kwargs=added_cond_kwargs,
-                ).sample
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+        for free_init_iter in range(num_free_init_iters):
+            if self.free_init_enabled:
+                latents, timesteps = self._apply_free_init(
+                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+                )
+                num_inference_steps = len(timesteps)
+                # make sure to readjust timesteps based on strength
+                timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+
+            self._num_timesteps = len(timesteps)
+            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+            # 8. Denoising loop
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=self.cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                    ).sample
+
+                    # perform guidance
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()

         if output_type == "latent":
             return AnimateDiffPipelineOutput(frames=latents)
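The `get_timesteps` re-call inside the FreeInit loop matters because `_apply_free_init` resets the scheduler and returns the full schedule; without re-truncating by `strength`, later iterations would denoise the input video from pure noise. A hedged trace with assumed values:

```python
# assume num_inference_steps=20 and strength=0.5
# iteration 0: timesteps were already truncated in step 4 -> 10 denoising steps
# iteration k>0: _apply_free_init returns the full schedule, so the pipeline redoes the cut:
num_inference_steps = len(timesteps)  # back to the full count (20 here)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
# -> the last 10 steps again, matching the first pass
```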
src/diffusers/pipelines/free_init_utils.py (new file, 184 lines)
@@ -0,0 +1,184 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple, Union

import torch
import torch.fft as fft

from ..utils.torch_utils import randn_tensor

class FreeInitMixin:
    r"""Mixin class for FreeInit."""

    def enable_free_init(
        self,
        num_iters: int = 3,
        use_fast_sampling: bool = False,
        method: str = "butterworth",
        order: int = 4,
        spatial_stop_frequency: float = 0.25,
        temporal_stop_frequency: float = 0.25,
    ):
        """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.

        This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).

        Args:
            num_iters (`int`, *optional*, defaults to `3`):
                Number of FreeInit noise re-initialization iterations.
            use_fast_sampling (`bool`, *optional*, defaults to `False`):
                Whether or not to speed up the sampling procedure at the cost of probably lower quality results.
                Enables the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
            method (`str`, *optional*, defaults to `butterworth`):
                Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the FreeInit
                low pass filter.
            order (`int`, *optional*, defaults to `4`):
                Order of the filter used in the `butterworth` method. Larger values lead to `ideal` method behaviour
                whereas lower values lead to `gaussian` method behaviour.
            spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
                Normalized stop frequency for spatial dimensions. Must be between 0 and 1. Referred to as `d_s` in
                the original implementation.
            temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
                Normalized stop frequency for temporal dimensions. Must be between 0 and 1. Referred to as `d_t` in
                the original implementation.
        """
        self._free_init_num_iters = num_iters
        self._free_init_use_fast_sampling = use_fast_sampling
        self._free_init_method = method
        self._free_init_order = order
        self._free_init_spatial_stop_frequency = spatial_stop_frequency
        self._free_init_temporal_stop_frequency = temporal_stop_frequency

    def disable_free_init(self):
        """Disables the FreeInit mechanism if enabled."""
        self._free_init_num_iters = None

    @property
    def free_init_enabled(self):
        return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None

    def _get_free_init_freq_filter(
        self,
        shape: Tuple[int, ...],
        device: Union[str, torch.device],
        filter_type: str,
        order: float,
        spatial_stop_frequency: float,
        temporal_stop_frequency: float,
    ) -> torch.Tensor:
        r"""Returns the FreeInit filter based on filter type and other input conditions."""

        time, height, width = shape[-3], shape[-2], shape[-1]
        mask = torch.zeros(shape)

        if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
            return mask

        if filter_type == "butterworth":

            def retrieve_mask(x):
                return 1 / (1 + (x / spatial_stop_frequency**2) ** order)

        elif filter_type == "gaussian":

            def retrieve_mask(x):
                return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)

        elif filter_type == "ideal":

            def retrieve_mask(x):
                return 1 if x <= spatial_stop_frequency * 2 else 0

        else:
            raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")

        for t in range(time):
            for h in range(height):
                for w in range(width):
                    d_square = (
                        ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
                        + (2 * h / height - 1) ** 2
                        + (2 * w / width - 1) ** 2
                    )
                    mask[..., t, h, w] = retrieve_mask(d_square)

        return mask.to(device)

    def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
        r"""Noise reinitialization."""
        # FFT
        x_freq = fft.fftn(x, dim=(-3, -2, -1))
        x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
        noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
        noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))

        # frequency mix
        high_pass_filter = 1 - low_pass_filter
        x_freq_low = x_freq * low_pass_filter
        noise_freq_high = noise_freq * high_pass_filter
        x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain

        # IFFT
        x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
        x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real

        return x_mixed

    def _apply_free_init(
        self,
        latents: torch.Tensor,
        free_init_iteration: int,
        num_inference_steps: int,
        device: torch.device,
        dtype: torch.dtype,
        generator: torch.Generator,
    ):
        if free_init_iteration == 0:
            self._free_init_initial_noise = latents.detach().clone()
            return latents, self.scheduler.timesteps

        latent_shape = latents.shape

        free_init_filter_shape = (1, *latent_shape[1:])
        free_init_freq_filter = self._get_free_init_freq_filter(
            shape=free_init_filter_shape,
            device=device,
            filter_type=self._free_init_method,
            order=self._free_init_order,
            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
        )

        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1  # diffuse to t=999 noise level
        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()

        z_t = self.scheduler.add_noise(
            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
        ).to(dtype=torch.float32)

        z_rand = randn_tensor(
            shape=latent_shape,
            generator=generator,
            device=device,
            dtype=torch.float32,
        )
        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
        latents = latents.to(dtype)

        # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
        if self._free_init_use_fast_sampling:
            num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
            self.scheduler.set_timesteps(num_inference_steps, device=device)

        return latents, self.scheduler.timesteps
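Two side notes on the mixin. First, `_get_free_init_freq_filter` fills the mask with a triple Python loop over (time, height, width); this runs once per FreeInit pass, but the same mask can be built with broadcasting. A sketch under the same grid convention, covering only the `butterworth` branch and omitting the zero-stop-frequency early return (this helper is illustrative, not part of the file):

```python
import torch

def butterworth_filter_vectorized(shape, d_s=0.25, d_t=0.25, order=4):
    # Broadcast-based equivalent of the `butterworth` branch above.
    # Uses the same 2*i/N - 1 grid and the same temporal rescaling by d_s/d_t.
    T, H, W = shape[-3], shape[-2], shape[-1]
    t = ((2 * torch.arange(T) / T - 1) * (d_s / d_t)).view(T, 1, 1)
    h = (2 * torch.arange(H) / H - 1).view(1, H, 1)
    w = (2 * torch.arange(W) / W - 1).view(1, 1, W)
    d_square = t**2 + h**2 + w**2
    mask = 1 / (1 + (d_square / d_s**2) ** order)
    return mask.broadcast_to(shape)  # e.g. shape = (1, C, T, H, W)
```

Second, the coarse-to-fine schedule is linear in the iteration index: with `use_fast_sampling=True` and `num_iters=3`, iterations after the first rescale the step count to `int(N / 3 * (i + 1))`. For `N = 30`, the second pass runs 20 steps and the third runs the full 30, while the first pass keeps whatever schedule was set before the loop (iteration 0 returns early, before the fast-sampling branch).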
@@ -45,6 +45,7 @@ from ...utils import (
     unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
+from ..free_init_utils import FreeInitMixin
 from ..pipeline_utils import DiffusionPipeline

@@ -210,7 +211,7 @@ class PIAPipelineOutput(BaseOutput):


 class PIAPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin, FreeInitMixin
 ):
     r"""
     Pipeline for text-to-video generation.
@@ -560,58 +561,6 @@ class PIAPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

-    @property
-    def free_init_enabled(self):
-        return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None
-
-    def enable_free_init(
-        self,
-        num_iters: int = 3,
-        use_fast_sampling: bool = False,
-        method: str = "butterworth",
-        order: int = 4,
-        spatial_stop_frequency: float = 0.25,
-        temporal_stop_frequency: float = 0.25,
-        generator: Optional[torch.Generator] = None,
-    ):
-        """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
-
-        This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
-
-        Args:
-            num_iters (`int`, *optional*, defaults to `3`):
-                Number of FreeInit noise re-initialization iterations.
-            use_fast_sampling (`bool`, *optional*, defaults to `False`):
-                Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
-                the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
-            method (`str`, *optional*, defaults to `butterworth`):
-                Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
-                FreeInit low pass filter.
-            order (`int`, *optional*, defaults to `4`):
-                Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
-                whereas lower values lead to `gaussian` method behaviour.
-            spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
-                Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
-                the original implementation.
-            temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
-                Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
-                the original implementation.
-            generator (`torch.Generator`, *optional*, defaults to `0.25`):
-                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
-                FreeInit generation deterministic.
-        """
-        self._free_init_num_iters = num_iters
-        self._free_init_use_fast_sampling = use_fast_sampling
-        self._free_init_method = method
-        self._free_init_order = order
-        self._free_init_spatial_stop_frequency = spatial_stop_frequency
-        self._free_init_temporal_stop_frequency = temporal_stop_frequency
-        self._free_init_generator = generator
-
-    def disable_free_init(self):
-        """Disables the FreeInit mechanism if enabled."""
-        self._free_init_num_iters = None
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
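With `FreeInitMixin` in the MRO, all three pipelines now share the same state handling; a minimal sketch of the toggling contract (pipeline construction elided):

```python
pipe.enable_free_init(num_iters=2, use_fast_sampling=True)
assert pipe.free_init_enabled  # property provided by FreeInitMixin

pipe.disable_free_init()       # resets _free_init_num_iters to None
assert not pipe.free_init_enabled
```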
@@ -795,143 +744,6 @@ class PIAPipeline(

         return mask, masked_image

-    def _denoise_loop(
-        self,
-        timesteps,
-        num_inference_steps,
-        do_classifier_free_guidance,
-        guidance_scale,
-        num_warmup_steps,
-        prompt_embeds,
-        negative_prompt_embeds,
-        latents,
-        mask,
-        masked_image,
-        cross_attention_kwargs,
-        added_cond_kwargs,
-        extra_step_kwargs,
-        callback_on_step_end,
-        callback_on_step_end_tensor_inputs,
-    ):
-        """Denoising loop for PIA."""
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-                latent_model_input = torch.cat([latent_model_input, mask, masked_image], dim=1)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    added_cond_kwargs=added_cond_kwargs,
-                ).sample
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-
-        return latents
-
-    def _free_init_loop(
-        self,
-        height,
-        width,
-        num_frames,
-        batch_size,
-        num_videos_per_prompt,
-        denoise_args,
-        device,
-    ):
-        """Denoising loop for PIA using FreeInit noise reinitialization technique."""
-
-        latents = denoise_args.get("latents")
-        prompt_embeds = denoise_args.get("prompt_embeds")
-        timesteps = denoise_args.get("timesteps")
-        num_inference_steps = denoise_args.get("num_inference_steps")
-
-        latent_shape = (
-            batch_size * num_videos_per_prompt,
-            4,
-            num_frames,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
-        free_init_filter_shape = (
-            1,
-            4,
-            num_frames,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
-        free_init_freq_filter = _get_freeinit_freq_filter(
-            shape=free_init_filter_shape,
-            device=device,
-            filter_type=self._free_init_method,
-            order=self._free_init_order,
-            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
-            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
-        )
-
-        with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar:
-            for i in range(self._free_init_num_iters):
-                # For the first FreeInit iteration, the original latent is used without modification.
-                # Subsequent iterations apply the noise reinitialization technique.
-                if i == 0:
-                    initial_noise = latents.detach().clone()
-                else:
-                    current_diffuse_timestep = (
-                        self.scheduler.config.num_train_timesteps - 1
-                    )  # diffuse to t=999 noise level
-                    diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long()
-                    z_T = self.scheduler.add_noise(
-                        original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device)
-                    ).to(dtype=torch.float32)
-                    z_rand = randn_tensor(
-                        shape=latent_shape,
-                        generator=self._free_init_generator,
-                        device=device,
-                        dtype=torch.float32,
-                    )
-                    latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter)
-                    latents = latents.to(prompt_embeds.dtype)
-
-                # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
-                if self._free_init_use_fast_sampling:
-                    current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1))
-                    self.scheduler.set_timesteps(current_num_inference_steps, device=device)
-                    timesteps = self.scheduler.timesteps
-                    denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps})
-
-                num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-                denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps})
-                latents = self._denoise_loop(**denoise_args)
-
-                free_init_progress_bar.update()
-
-        return latents
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
@@ -944,19 +756,6 @@ class PIAPipeline(

         return timesteps, num_inference_steps - t_start

-    def _retrieve_video_frames(self, latents, output_type, return_dict):
-        """Helper function to handle latents to output conversion."""
-        if output_type == "latent":
-            return PIAPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
-
-        if not return_dict:
-            return (video,)
-
-        return PIAPipelineOutput(frames=video)
-
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -1191,41 +990,62 @@ class PIAPipeline(
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

-        # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        denoise_args = {
-            "timesteps": timesteps,
-            "num_inference_steps": num_inference_steps,
-            "do_classifier_free_guidance": self.do_classifier_free_guidance,
-            "guidance_scale": guidance_scale,
-            "num_warmup_steps": num_warmup_steps,
-            "prompt_embeds": prompt_embeds,
-            "negative_prompt_embeds": negative_prompt_embeds,
-            "latents": latents,
-            "mask": mask,
-            "masked_image": masked_image,
-            "cross_attention_kwargs": self.cross_attention_kwargs,
-            "added_cond_kwargs": added_cond_kwargs,
-            "extra_step_kwargs": extra_step_kwargs,
-            "callback_on_step_end": callback_on_step_end,
-            "callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs,
-        }
-
-        if self.free_init_enabled:
-            latents = self._free_init_loop(
-                height=height,
-                width=width,
-                num_frames=num_frames,
-                batch_size=batch_size,
-                num_videos_per_prompt=num_videos_per_prompt,
-                denoise_args=denoise_args,
-                device=device,
-            )
-        else:
-            latents = self._denoise_loop(**denoise_args)
-
-        video = self._retrieve_video_frames(latents, output_type, return_dict)
+        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+        for free_init_iter in range(num_free_init_iters):
+            if self.free_init_enabled:
+                latents, timesteps = self._apply_free_init(
+                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+                )
+
+            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image], dim=1)
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                    ).sample
+
+                    # perform guidance
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+
+        if output_type == "latent":
+            return PIAPipelineOutput(frames=latents)
+
+        video_tensor = self.decode_latents(latents)
+        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

         # 9. Offload all models
         self.maybe_free_model_hooks()

-        return video
+        if not return_dict:
+            return (video,)
+
+        return PIAPipelineOutput(frames=video)
@@ -242,7 +242,6 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         inputs_normal = self.get_dummy_inputs(torch_device)
         frames_normal = pipe(**inputs_normal).frames[0]

-        free_init_generator = torch.Generator(device=torch_device).manual_seed(0)
         pipe.enable_free_init(
             num_iters=2,
             use_fast_sampling=True,
@@ -250,7 +249,6 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             order=4,
             spatial_stop_frequency=0.25,
             temporal_stop_frequency=0.25,
-            generator=free_init_generator,
         )
         inputs_enable_free_init = self.get_dummy_inputs(torch_device)
         frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
@@ -267,3 +267,38 @@ class AnimateDiffVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.Tes

         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
+
+    def test_free_init(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        inputs_normal = self.get_dummy_inputs(torch_device)
+        frames_normal = pipe(**inputs_normal).frames[0]
+
+        pipe.enable_free_init(
+            num_iters=2,
+            use_fast_sampling=True,
+            method="butterworth",
+            order=4,
+            spatial_stop_frequency=0.25,
+            temporal_stop_frequency=0.25,
+        )
+        inputs_enable_free_init = self.get_dummy_inputs(torch_device)
+        frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
+
+        pipe.disable_free_init()
+        inputs_disable_free_init = self.get_dummy_inputs(torch_device)
+        frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
+
+        sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
+        max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
+        self.assertGreater(
+            sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
+        )
+        self.assertLess(
+            max_diff_disabled,
+            1e-4,
+            "Disabling of FreeInit should lead to results similar to the default pipeline results",
+        )
@@ -255,7 +255,6 @@ class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         inputs_normal = self.get_dummy_inputs(torch_device)
         frames_normal = pipe(**inputs_normal).frames[0]

-        free_init_generator = torch.Generator(device=torch_device).manual_seed(0)
         pipe.enable_free_init(
             num_iters=2,
             use_fast_sampling=True,
@@ -263,7 +262,6 @@ class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             order=4,
             spatial_stop_frequency=0.25,
             temporal_stop_frequency=0.25,
-            generator=free_init_generator,
         )
         inputs_enable_free_init = self.get_dummy_inputs(torch_device)
         frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]