* freeinit

* update freeinit implementation based on review

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* fix

* another fix

* refactor

* fix timesteps missing bug

* apply suggestions from review

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* add test for freeinit

* apply suggestions from review

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* refactor

* fix test

* fix tensor not on same device

* update

* remove return_intermediate_results

* fix broken freeinit test

* update animatediff docs

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Aryan V S
2024-01-17 17:17:07 +05:30
committed by GitHub
parent dce06680d2
commit 9112028ed8
3 changed files with 484 additions and 66 deletions

View File

@@ -235,6 +235,62 @@ export_to_gif(frames, "animation.gif")
</tr>
</table>
## Using FreeInit
[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.
FreeInit is an effective method that improves the temporal consistency and overall quality of videos generated with video diffusion models, without any additional training. It can be applied seamlessly to AnimateDiff, ModelScope, VideoCrafter, and various other video generation models at inference time, and works by iteratively refining the latent initialization noise. More details can be found in the paper.
The following example demonstrates the usage of FreeInit.
```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    clip_sample=False,
    timestep_spacing="linspace",
    steps_offset=1
)

# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

# enable FreeInit
# Refer to the enable_free_init documentation for a full list of configurable parameters
pipe.enable_free_init(method="butterworth", use_fast_sampling=True)

# run inference
output = pipe(
    prompt="a panda playing a guitar, on a boat, in the ocean, high quality",
    negative_prompt="bad quality, worse quality",
    num_frames=16,
    guidance_scale=7.5,
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(666),
)

# disable FreeInit
pipe.disable_free_init()

frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
<Tip warning={true}>
FreeInit is not really free: the improved quality comes at the cost of extra computation. It requires sampling a few extra times, depending on the `num_iters` parameter set when enabling it. Setting `use_fast_sampling=True` improves overall performance at the cost of some quality compared to `use_fast_sampling=False`, but the results are still better than those of the vanilla video generation model.
</Tip>
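As a rough guide (illustrative configurations, not taken from the paper), a quality-first setup keeps the default number of iterations with full sampling, while a speed-first setup lowers `num_iters` and enables coarse-to-fine sampling:

```python
# quality-first: full sampling at every FreeInit iteration (roughly num_iters x the base cost)
pipe.enable_free_init(num_iters=3, use_fast_sampling=False)

# speed-first: fewer iterations with coarse-to-fine sampling (faster, usually slightly lower quality)
pipe.enable_free_init(num_iters=2, use_fast_sampling=True)
```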
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
@@ -248,6 +304,8 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- __call__
- enable_freeu
- disable_freeu
- enable_free_init
- disable_free_init
- enable_vae_slicing
- disable_vae_slicing
- enable_vae_tiling

View File

@@ -13,11 +13,13 @@
# limitations under the License.
import inspect
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.fft as fft
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
@@ -36,6 +38,7 @@ from ...schedulers import (
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -79,6 +82,71 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
return outputs
def _get_freeinit_freq_filter(
shape: Tuple[int, ...],
device: Union[str, torch.device],
filter_type: str,
order: float,
spatial_stop_frequency: float,
temporal_stop_frequency: float,
) -> torch.Tensor:
r"""Returns the FreeInit filter based on filter type and other input conditions."""
T, H, W = shape[-3], shape[-2], shape[-1]
mask = torch.zeros(shape)
if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
return mask
if filter_type == "butterworth":
def retrieve_mask(x):
return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
elif filter_type == "gaussian":
def retrieve_mask(x):
return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
elif filter_type == "ideal":
def retrieve_mask(x):
return 1 if x <= spatial_stop_frequency * 2 else 0
else:
raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
for t in range(T):
for h in range(H):
for w in range(W):
d_square = (
((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2
+ (2 * h / H - 1) ** 2
+ (2 * w / W - 1) ** 2
)
mask[..., t, h, w] = retrieve_mask(d_square)
return mask.to(device)
def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
r"""Noise reinitialization."""
# FFT
x_freq = fft.fftn(x, dim=(-3, -2, -1))
x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
# frequency mix
HPF = 1 - LPF
x_freq_low = x_freq * LPF
noise_freq_high = noise_freq * HPF
x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
# IFFT
x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
return x_mixed
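Taken together, the two helpers above implement FreeInit's noise reinitialization: the low-pass filter keeps the low-frequency (coarse layout) components of the re-diffused latents and fills in the high-frequency components from freshly sampled noise. A minimal standalone sketch of that behaviour (illustrative only, not part of this diff; it assumes both helpers are in scope and uses an arbitrary toy latent shape):

```python
import torch

# toy latent video shape: (batch, channels, frames, height, width)
shape = (1, 4, 16, 8, 8)

lpf = _get_freeinit_freq_filter(
    shape=shape,
    device="cpu",
    filter_type="butterworth",
    order=4,
    spatial_stop_frequency=0.25,
    temporal_stop_frequency=0.25,
)

z_T = torch.randn(shape)     # latents re-diffused to the t=T noise level
z_rand = torch.randn(shape)  # freshly sampled noise
z_reinit = _freq_mix_3d(z_T, z_rand, LPF=lpf)  # low freqs from z_T, high freqs from z_rand
```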
@dataclass
class AnimateDiffPipelineOutput(BaseOutput):
frames: Union[torch.Tensor, np.ndarray]
@@ -115,6 +183,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -442,6 +511,58 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@property
def free_init_enabled(self):
return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None
def enable_free_init(
self,
num_iters: int = 3,
use_fast_sampling: bool = False,
method: str = "butterworth",
order: int = 4,
spatial_stop_frequency: float = 0.25,
temporal_stop_frequency: float = 0.25,
generator: torch.Generator = None,
):
"""Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
Args:
num_iters (`int`, *optional*, defaults to `3`):
Number of FreeInit noise re-initialization iterations.
use_fast_sampling (`bool`, *optional*, defaults to `False`):
Whether or not to speed up the sampling procedure at the cost of probably lower quality results. Enables
the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
method (`str`, *optional*, defaults to `butterworth`):
Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
FreeInit low pass filter.
order (`int`, *optional*, defaults to `4`):
Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
whereas lower values lead to `gaussian` method behaviour.
spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
Normalized stop frequency for spatial dimensions. Must be between 0 and 1. Referred to as `d_s` in
the original implementation.
temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
Normalized stop frequency for temporal dimensions. Must be between 0 and 1. Referred to as `d_t` in
the original implementation.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
FreeInit generation deterministic.
"""
self._free_init_num_iters = num_iters
self._free_init_use_fast_sampling = use_fast_sampling
self._free_init_method = method
self._free_init_order = order
self._free_init_spatial_stop_frequency = spatial_stop_frequency
self._free_init_temporal_stop_frequency = temporal_stop_frequency
self._free_init_generator = generator
def disable_free_init(self):
"""Disables the FreeInit mechanism if enabled."""
self._free_init_num_iters = None
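As a quick sanity check of the new API (illustrative, not part of this diff), the `free_init_enabled` property reports whether FreeInit is currently active:

```python
pipe.enable_free_init(num_iters=3, method="gaussian")
assert pipe.free_init_enabled

pipe.disable_free_init()
assert not pipe.free_init_enabled
```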
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -539,6 +660,185 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
latents = latents * self.scheduler.init_noise_sigma
return latents
def _denoise_loop(
self,
timesteps,
num_inference_steps,
do_classifier_free_guidance,
guidance_scale,
num_warmup_steps,
prompt_embeds,
negative_prompt_embeds,
latents,
cross_attention_kwargs,
added_cond_kwargs,
extra_step_kwargs,
callback,
callback_steps,
callback_on_step_end,
callback_on_step_end_tensor_inputs,
):
"""Denoising loop for AnimateDiff."""
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
return latents
def _free_init_loop(
self,
height,
width,
num_frames,
num_channels_latents,
batch_size,
num_videos_per_prompt,
denoise_args,
device,
):
"""Denoising loop for AnimateDiff using FreeInit noise reinitialization technique."""
latents = denoise_args.get("latents")
prompt_embeds = denoise_args.get("prompt_embeds")
timesteps = denoise_args.get("timesteps")
num_inference_steps = denoise_args.get("num_inference_steps")
latent_shape = (
batch_size * num_videos_per_prompt,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
free_init_filter_shape = (
1,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
free_init_freq_filter = _get_freeinit_freq_filter(
shape=free_init_filter_shape,
device=device,
filter_type=self._free_init_method,
order=self._free_init_order,
spatial_stop_frequency=self._free_init_spatial_stop_frequency,
temporal_stop_frequency=self._free_init_temporal_stop_frequency,
)
with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar:
for i in range(self._free_init_num_iters):
# For the first FreeInit iteration, the original latent is used without modification.
# Subsequent iterations apply the noise reinitialization technique.
if i == 0:
initial_noise = latents.detach().clone()
else:
current_diffuse_timestep = (
self.scheduler.config.num_train_timesteps - 1
) # diffuse to t=999 noise level
diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long()
z_T = self.scheduler.add_noise(
original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device)
).to(dtype=torch.float32)
z_rand = randn_tensor(
shape=latent_shape,
generator=self._free_init_generator,
device=device,
dtype=torch.float32,
)
latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter)
latents = latents.to(prompt_embeds.dtype)
# Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
if self._free_init_use_fast_sampling:
current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1))
self.scheduler.set_timesteps(current_num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps})
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps})
latents = self._denoise_loop(**denoise_args)
free_init_progress_bar.update()
return latents
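With `use_fast_sampling=True`, the number of denoising steps grows linearly with the FreeInit iteration index, so earlier (coarser) iterations are cheaper and only the final iteration uses the full budget. A worked example of the schedule computed above (illustrative numbers):

```python
num_inference_steps, num_iters = 20, 3
schedule = [int(num_inference_steps / num_iters * (i + 1)) for i in range(num_iters)]
print(schedule)  # [6, 13, 20]
```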
def _retrieve_video_frames(self, latents, output_type, return_dict):
"""Helper function to handle latents to output conversion."""
if output_type == "latent":
return AnimateDiffPipelineOutput(frames=latents)
video_tensor = self.decode_latents(latents)
if output_type == "pt":
video = video_tensor
else:
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
if not return_dict:
return (video,)
return AnimateDiffPipelineOutput(frames=video)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -559,10 +859,11 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -603,25 +904,30 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
of a plain tuple.
callback (`Callable`, *optional*):
A function that is called every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
@@ -629,6 +935,23 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -637,9 +960,20 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -649,30 +983,26 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
@@ -680,12 +1010,13 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
)
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -703,55 +1034,47 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7 Add image embeds for IP-Adapter
# 7. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# Denoising loop
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
denoise_args = {
"timesteps": timesteps,
"num_inference_steps": num_inference_steps,
"do_classifier_free_guidance": self.do_classifier_free_guidance,
"guidance_scale": guidance_scale,
"num_warmup_steps": num_warmup_steps,
"prompt_embeds": prompt_embeds,
"negative_prompt_embeds": negative_prompt_embeds,
"latents": latents,
"cross_attention_kwargs": self.cross_attention_kwargs,
"added_cond_kwargs": added_cond_kwargs,
"extra_step_kwargs": extra_step_kwargs,
"callback": callback,
"callback_steps": callback_steps,
"callback_on_step_end": callback_on_step_end,
"callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs,
}
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if output_type == "latent":
return AnimateDiffPipelineOutput(frames=latents)
# Post-processing
video_tensor = self.decode_latents(latents)
if output_type == "pt":
video = video_tensor
if self.free_init_enabled:
latents = self._free_init_loop(
height=height,
width=width,
num_frames=num_frames,
num_channels_latents=num_channels_latents,
batch_size=batch_size,
num_videos_per_prompt=num_videos_per_prompt,
denoise_args=denoise_args,
device=device,
)
else:
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
latents = self._denoise_loop(**denoise_args)
# Offload all models
video = self._retrieve_video_frames(latents, output_type, return_dict)
# 9. Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return AnimateDiffPipelineOutput(frames=video)
return video
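The diff above also migrates `__call__` from the deprecated `callback`/`callback_steps` arguments to `callback_on_step_end`. A hedged usage sketch of the new interface (the callback name and prompt are illustrative; the callback receives the pipeline, the step index, the timestep, and a dict of the tensors listed in `callback_on_step_end_tensor_inputs`, and must return a dict):

```python
def log_latent_stats(pipeline, step, timestep, callback_kwargs):
    # inspect (or modify) tensors at the end of each denoising step
    latents = callback_kwargs["latents"]
    print(f"step {step:03d} | t={int(timestep)} | latents std = {latents.std().item():.4f}")
    return callback_kwargs  # returned entries are written back into the denoising loop

output = pipe(
    prompt="a panda playing a guitar, on a boat, in the ocean, high quality",
    num_frames=16,
    num_inference_steps=20,
    callback_on_step_end=log_latent_stats,
    callback_on_step_end_tensor_inputs=["latents"],
)
```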

View File

@@ -38,8 +38,8 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"generator",
"latents",
"return_dict",
"callback",
"callback_steps",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
@@ -233,6 +233,43 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
pipe(**inputs)
def test_free_init(self):
components = self.get_dummy_components()
pipe: AnimateDiffPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]
free_init_generator = torch.Generator(device=torch_device).manual_seed(0)
pipe.enable_free_init(
num_iters=2,
use_fast_sampling=True,
method="butterworth",
order=4,
spatial_stop_frequency=0.25,
temporal_stop_frequency=0.25,
generator=free_init_generator,
)
inputs_enable_free_init = self.get_dummy_inputs(torch_device)
frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
pipe.disable_free_init()
inputs_disable_free_init = self.get_dummy_inputs(torch_device)
frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
self.assertGreater(
sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results"
)
self.assertLess(
max_diff_disabled,
1e-4,
"Disabling of FreeInit should lead to results similar to the default pipeline results",
)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",