mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
✨ [Core] Add FreeU mechanism (#5164)
* ✨ Added Fourier filter function to upsample blocks * 🔧 Update Fourier_filter for float16 support * ✨ Added UNetFreeUConfig to UNet model for FreeU adaptation 🛠️ * move unet to its original form and add fourier_filter to torch_utils. * implement freeU enable mechanism * implement disable mechanism * resolution index. * correct resolution idx condition. * fix copies. * no need to use resolution_idx in vae. * spell out the kwargs * proper config property * fix attribution setting * place unet hasattr properly. * fix: attribute access. * proper disable * remove validation method. * debug * debug * debug * debug * debug * debug * potential fix. * add: doc. * fix copies * add: tests. * add: support freeU in SDXL. * set default value of resolution idx. * set default values for resolution_idx. * fix copies * fix rest. * fix copies * address PR comments. * run fix-copies * move apply_free_u to utils and other minors. * introduce support for video (unet3D) * minor ups * consistent fix-copies. * consistent stuff * fix-copies * add: rest * add: docs. * fix: tests * fix: doc path * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * style up * move to techniques. * add: slow test for sd freeu. * add: slow test for sd freeu. * add: slow test for sd freeu. * add: slow test for sd freeu. * add: slow test for sd freeu. * add: slow test for sd freeu. * add: slow test for video with freeu * add: slow test for video with freeu * add: slow test for video with freeu * style --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -58,6 +58,8 @@
|
||||
title: Control image brightness
|
||||
- local: using-diffusers/weighted_prompts
|
||||
title: Prompt weighting
|
||||
- local: using-diffusers/freeu
|
||||
title: Improve generation quality with FreeU
|
||||
title: Techniques
|
||||
- sections:
|
||||
- local: using-diffusers/pipeline_overview
|
||||
|
||||
123
docs/source/en/using-diffusers/freeu.md
Normal file
123
docs/source/en/using-diffusers/freeu.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# Improve generation quality with FreeU
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
The UNet is responsible for denoising during the reverse diffusion process, and there are two distinct features in its architecture:
|
||||
|
||||
1. Backbone features primarily contribute to the denoising process
|
||||
2. Skip features mainly introduce high-frequency features into the decoder module and can make the network overlook the semantics in the backbone features
|
||||
|
||||
However, the skip connection can sometimes introduce unnatural image details. [FreeU](https://hf.co/papers/2309.11497) is a technique for improving image quality by rebalancing the contributions from the UNet’s skip connections and backbone feature maps.
|
||||
|
||||
FreeU is applied during inference and it does not require any additional training. The technique works for different tasks such as text-to-image, image-to-image, and text-to-video.
|
||||
|
||||
In this guide, you will apply FreeU to the [`StableDiffusionPipeline`], [`StableDiffusionXLPipeline`], and [`TextToVideoSDPipeline`].
|
||||
|
||||
## StableDiffusionPipeline
|
||||
|
||||
Load the pipeline:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Then enable the FreeU mechanism with the FreeU-specific hyperparameters. These values are scaling factors for the backbone and skip features.
|
||||
|
||||
```py
|
||||
pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
|
||||
```
|
||||
|
||||
The values above are from the official FreeU [code repository](https://github.com/ChenyangSi/FreeU) where you can also find [reference hyperparameters](https://github.com/ChenyangSi/FreeU#range-for-more-parameters) for different models.
|
||||
|
||||
<Tip>
|
||||
|
||||
Disable the FreeU mechanism by calling `disable_freeu()` on a pipeline.
|
||||
|
||||
</Tip>
|
||||
|
||||
And then run inference:
|
||||
|
||||
```py
|
||||
prompt = "A squirrel eating a burger"
|
||||
seed = 2023
|
||||
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
|
||||
```
|
||||
|
||||
The figure below compares non-FreeU and FreeU results respectively for the same hyperparameters used above (`prompt` and `seed`):
|
||||
|
||||

|
||||
|
||||
|
||||
Let's see how Stable Diffusion 2 results are impacted:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, safety_checker=None
|
||||
).to("cuda")
|
||||
|
||||
prompt = "A squirrel eating a burger"
|
||||
seed = 2023
|
||||
|
||||
pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
|
||||
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
|
||||
```
|
||||
|
||||
|
||||

|
||||
|
||||
## Stable Diffusion XL
|
||||
|
||||
Finally, let's take a look at how FreeU affects Stable Diffusion XL results:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "A squirrel eating a burger"
|
||||
seed = 2023
|
||||
|
||||
# Comes from
|
||||
# https://wandb.ai/nasirk24/UNET-FreeU-SDXL/reports/FreeU-SDXL-Optimal-Parameters--Vmlldzo1NDg4NTUw
|
||||
pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
|
||||
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
|
||||
```
|
||||
|
||||
|
||||

|
||||
|
||||
## Text-to-video generation
|
||||
|
||||
FreeU can also be used to improve video quality:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
import torch
|
||||
|
||||
model_id = "cerspense/zeroscope_v2_576w"
|
||||
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16).to("cuda")
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "an astronaut riding a horse on mars"
|
||||
seed = 2023
|
||||
|
||||
# The values come from
|
||||
# https://github.com/lyn-rgb/FreeU_Diffusers#video-pipelines
|
||||
pipe.enable_freeu(b1=1.2, b2=1.4, s1=0.9, s2=0.2)
|
||||
video_frames = pipe(prompt, height=320, width=576, num_frames=30, generator=torch.manual_seed(seed)).frames
|
||||
export_to_video(video_frames, "astronaut_rides_horse.mp4")
|
||||
```
|
||||
|
||||
Thanks to [kadirnar](https://github.com/kadirnar/) for helping to integrate the feature, and to [justindujardin](https://github.com/justindujardin) for the helpful discussions.
|
||||
@@ -19,6 +19,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import is_torch_version, logging
|
||||
from ..utils.torch_utils import apply_freeu
|
||||
from .activations import get_activation
|
||||
from .attention import AdaGroupNorm
|
||||
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
||||
@@ -249,6 +250,7 @@ def get_up_block(
|
||||
add_upsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
resolution_idx=None,
|
||||
transformer_layers_per_block=1,
|
||||
num_attention_heads=None,
|
||||
resnet_groups=None,
|
||||
@@ -281,6 +283,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -295,6 +298,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -314,6 +318,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -337,6 +342,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -362,6 +368,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
@@ -377,6 +384,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -390,6 +398,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -402,6 +411,7 @@ def get_up_block(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -415,6 +425,7 @@ def get_up_block(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -430,6 +441,7 @@ def get_up_block(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -441,6 +453,7 @@ def get_up_block(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -1993,6 +2006,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2075,6 +2089,8 @@ class AttnUpBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
@@ -2103,6 +2119,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
@@ -2181,6 +2198,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -2194,11 +2212,30 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -2252,6 +2289,7 @@ class UpBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2292,12 +2330,33 @@ class UpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -2331,6 +2390,7 @@ class UpDecoderBlock2D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2370,6 +2430,8 @@ class UpDecoderBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, temb=None, scale: float = 1.0):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
||||
@@ -2386,6 +2448,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2449,6 +2512,8 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, temb=None, scale: float = 1.0):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
||||
@@ -2469,6 +2534,7 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2553,6 +2619,8 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
self.skip_norm = None
|
||||
self.act = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
@@ -2589,6 +2657,7 @@ class SkipUpBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2651,6 +2720,8 @@ class SkipUpBlock2D(nn.Module):
|
||||
self.skip_norm = None
|
||||
self.act = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
@@ -2684,6 +2755,7 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2743,6 +2815,7 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
|
||||
for resnet in self.resnets:
|
||||
@@ -2784,6 +2857,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2873,6 +2947,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -2947,6 +3022,7 @@ class KUpBlock2D(nn.Module):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 5,
|
||||
resnet_eps: float = 1e-5,
|
||||
@@ -2988,6 +3064,7 @@ class KUpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[-1]
|
||||
@@ -3027,6 +3104,7 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 4,
|
||||
resnet_eps: float = 1e-5,
|
||||
@@ -3104,6 +3182,7 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -542,6 +542,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resolution_idx=i,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=reversed_cross_attention_dim[i],
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
@@ -733,6 +734,38 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
||||
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
setattr(upsample_block, "s1", s1)
|
||||
setattr(upsample_block, "s2", s2)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
for k in freeu_keys:
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils.torch_utils import apply_freeu
|
||||
from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
@@ -87,6 +88,7 @@ def get_up_block(
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
num_attention_heads,
|
||||
resolution_idx=None,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
dual_cross_attention=False,
|
||||
@@ -107,6 +109,7 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resolution_idx=resolution_idx,
|
||||
)
|
||||
elif up_block_type == "CrossAttnUpBlock3D":
|
||||
if cross_attention_dim is None:
|
||||
@@ -128,6 +131,7 @@ def get_up_block(
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resolution_idx=resolution_idx,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
@@ -496,6 +500,7 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resolution_idx=None,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -565,6 +570,7 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -577,6 +583,13 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
num_frames=1,
|
||||
cross_attention_kwargs=None,
|
||||
):
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
for resnet, temp_conv, attn, temp_attn in zip(
|
||||
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
||||
@@ -584,6 +597,19 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -621,6 +647,7 @@ class UpBlock3D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
resolution_idx=None,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -661,12 +688,32 @@ class UpBlock3D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -255,6 +255,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
dual_cross_attention=False,
|
||||
resolution_idx=i,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
@@ -462,6 +463,40 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
||||
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
setattr(upsample_block, "s1", s1)
|
||||
setattr(upsample_block, "s2", s2)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
for k in freeu_keys:
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
||||
@@ -547,6 +547,32 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Alt Diffusion v1, v2, and Alt Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
|
||||
@@ -537,6 +537,32 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
|
||||
@@ -560,6 +560,34 @@ class StableDiffusionXLPipeline(
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1=0.9, s2=0.2, b1=1.2, b2=1.4):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
|
||||
@@ -472,6 +472,34 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
|
||||
@@ -32,6 +32,7 @@ from ...models.embeddings import (
|
||||
from ...models.transformer_2d import Transformer2DModel
|
||||
from ...models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -749,6 +750,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resolution_idx=i,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=reversed_cross_attention_dim[i],
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
@@ -941,6 +943,38 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
||||
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
setattr(upsample_block, "s1", s1)
|
||||
setattr(upsample_block, "s2", s2)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
for k in freeu_keys:
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -1630,6 +1664,7 @@ class UpBlockFlat(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -1670,12 +1705,33 @@ class UpBlockFlat(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -1712,6 +1768,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
@@ -1790,6 +1847,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1803,11 +1861,30 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -22,6 +22,7 @@ from .import_utils import is_torch_available, is_torch_version
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.fft import fftn, fftshift, ifftn, ifftshift
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -86,3 +87,61 @@ def is_compiled_module(module):
|
||||
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
|
||||
return False
|
||||
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
|
||||
|
||||
|
||||
def fourier_filter(x_in, threshold, scale):
|
||||
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
|
||||
|
||||
This version of the method comes from here:
|
||||
https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
|
||||
"""
|
||||
x = x_in
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Non-power of 2 images must be float32
|
||||
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
|
||||
x = x.to(dtype=torch.float32)
|
||||
|
||||
# FFT
|
||||
x_freq = fftn(x, dim=(-2, -1))
|
||||
x_freq = fftshift(x_freq, dim=(-2, -1))
|
||||
|
||||
B, C, H, W = x_freq.shape
|
||||
mask = torch.ones((B, C, H, W), device=x.device)
|
||||
|
||||
crow, ccol = H // 2, W // 2
|
||||
mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
|
||||
x_freq = x_freq * mask
|
||||
|
||||
# IFFT
|
||||
x_freq = ifftshift(x_freq, dim=(-2, -1))
|
||||
x_filtered = ifftn(x_freq, dim=(-2, -1)).real
|
||||
|
||||
return x_filtered.to(dtype=x_in.dtype)
|
||||
|
||||
|
||||
def apply_freeu(
|
||||
resolution_idx: int, hidden_states: torch.Tensor, res_hidden_states: torch.Tensor, **freeu_kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Applies the FreeU mechanism as introduced in https:
|
||||
//arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
|
||||
|
||||
Args:
|
||||
resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
|
||||
hidden_states (`torch.Tensor`): Inputs to the underlying block.
|
||||
res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
|
||||
s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
|
||||
s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if resolution_idx == 0:
|
||||
num_half_channels = hidden_states.shape[1] // 2
|
||||
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
|
||||
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
|
||||
if resolution_idx == 1:
|
||||
num_half_channels = hidden_states.shape[1] // 2
|
||||
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
|
||||
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
|
||||
|
||||
return hidden_states, res_hidden_states
|
||||
|
||||
@@ -565,6 +565,47 @@ class StableDiffusionPipelineFastTests(
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
def test_freeu_enabled(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "hey"
|
||||
output = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
|
||||
|
||||
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
|
||||
output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
|
||||
|
||||
assert not np.allclose(
|
||||
output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
|
||||
), "Enabling of FreeU should lead to different results."
|
||||
|
||||
def test_freeu_disabled(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "hey"
|
||||
output = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
|
||||
|
||||
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
|
||||
sd_pipe.disable_freeu()
|
||||
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
for upsample_block in sd_pipe.unet.up_blocks:
|
||||
for key in freeu_keys:
|
||||
assert getattr(upsample_block, key) is None, f"Disabling of FreeU should have set {key} to None."
|
||||
|
||||
output_no_freeu = sd_pipe(
|
||||
prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)
|
||||
).images
|
||||
|
||||
assert np.allclose(
|
||||
output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
|
||||
), "Disabling of FreeU should lead to results similar to the default pipeline results."
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@@ -600,6 +641,20 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
|
||||
expected_slice = np.array([0.43625, 0.43554, 0.36670, 0.40660, 0.39703, 0.38658, 0.43936, 0.43557, 0.40592])
|
||||
assert np.abs(image_slice - expected_slice).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_v1_4_with_freeu(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 25
|
||||
|
||||
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
|
||||
image = sd_pipe(**inputs).images
|
||||
image = image[0, -3:, -3:, -1].flatten()
|
||||
expected_image = [0.0721, 0.0588, 0.0268, 0.0384, 0.0636, 0.0, 0.0429, 0.0344, 0.0309]
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_stable_diffusion_1_4_pndm(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
@@ -1079,7 +1134,7 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase):
|
||||
"generator": generator,
|
||||
"num_inference_steps": 50,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
@@ -1155,19 +1210,3 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase):
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_stable_diffusion_dpm(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
|
||||
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 25
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_text2img/stable_diffusion_1_4_dpm_multi.npy"
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@@ -193,3 +193,21 @@ class TextToVideoSDPipelineSlowTests(unittest.TestCase):
|
||||
video = video_frames.cpu().numpy()
|
||||
|
||||
assert np.abs(expected_video - video).mean() < 5e-2
|
||||
|
||||
def test_two_step_model_with_freeu(self):
|
||||
expected_video = []
|
||||
|
||||
pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
prompt = "Spiderman is surfing"
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
|
||||
video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
|
||||
video = video_frames.cpu().numpy()
|
||||
video = video[0, 0, -3:, -3:, -1].flatten()
|
||||
|
||||
expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177]
|
||||
|
||||
assert np.abs(expected_video - video).mean() < 5e-2
|
||||
|
||||
Reference in New Issue
Block a user