1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Core] Add FreeU mechanism (#5164)

*  Added Fourier filter function to upsample blocks

* 🔧 Update Fourier_filter for float16 support

*  Added UNetFreeUConfig to UNet model for FreeU adaptation 🛠️

* move unet to its original form and add fourier_filter to torch_utils.

* implement freeU enable mechanism

* implement disable mechanism

* resolution index.

* correct resolution idx condition.

* fix copies.

* no need to use resolution_idx in vae.

* spell out the kwargs

* proper config property

* fix attribution setting

* place unet hasattr properly.

* fix: attribute access.

* proper disable

* remove validation method.

* debug

* debug

* debug

* debug

* debug

* debug

* potential fix.

* add: doc.

* fix copies

* add: tests.

* add: support freeU in SDXL.

* set default value of resolution idx.

* set default values for resolution_idx.

* fix copies

* fix rest.

* fix copies

* address PR comments.

* run fix-copies

* move apply_free_u to utils and other minors.

* introduce support for video (unet3D)

* minor ups

* consistent fix-copies.

* consistent stuff

* fix-copies

* add: rest

* add: docs.

* fix: tests

* fix: doc path

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style up

* move to techniques.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for video with freeu

* add: slow test for video with freeu

* add: slow test for video with freeu

* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Kadir Nar
2023-10-05 11:37:04 +03:00
committed by GitHub
parent e46ec5f88f
commit 84b82a6cb7
14 changed files with 637 additions and 17 deletions

View File

@@ -58,6 +58,8 @@
title: Control image brightness
- local: using-diffusers/weighted_prompts
title: Prompt weighting
- local: using-diffusers/freeu
title: Improve generation quality with FreeU
title: Techniques
- sections:
- local: using-diffusers/pipeline_overview

View File

@@ -0,0 +1,123 @@
# Improve generation quality with FreeU
[[open-in-colab]]
The UNet is responsible for denoising during the reverse diffusion process, and there are two distinct features in its architecture:
1. Backbone features primarily contribute to the denoising process
2. Skip features mainly introduce high-frequency features into the decoder module and can make the network overlook the semantics in the backbone features
However, the skip connection can sometimes introduce unnatural image details. [FreeU](https://hf.co/papers/2309.11497) is a technique for improving image quality by rebalancing the contributions from the UNets skip connections and backbone feature maps.
FreeU is applied during inference and it does not require any additional training. The technique works for different tasks such as text-to-image, image-to-image, and text-to-video.
In this guide, you will apply FreeU to the [`StableDiffusionPipeline`], [`StableDiffusionXLPipeline`], and [`TextToVideoSDPipeline`].
## StableDiffusionPipeline
Load the pipeline:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
```
Then enable the FreeU mechanism with the FreeU-specific hyperparameters. These values are scaling factors for the backbone and skip features.
```py
pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
```
The values above are from the official FreeU [code repository](https://github.com/ChenyangSi/FreeU) where you can also find [reference hyperparameters](https://github.com/ChenyangSi/FreeU#range-for-more-parameters) for different models.
<Tip>
Disable the FreeU mechanism by calling `disable_freeu()` on a pipeline.
</Tip>
And then run inference:
```py
prompt = "A squirrel eating a burger"
seed = 2023
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
```
The figure below compares non-FreeU and FreeU results respectively for the same hyperparameters used above (`prompt` and `seed`):
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/freeu/sdv1_5_freeu.jpg)
Let's see how Stable Diffusion 2 results are impacted:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
prompt = "A squirrel eating a burger"
seed = 2023
pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
```
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/freeu/sdv2_1_freeu.jpg)
## Stable Diffusion XL
Finally, let's take a look at how FreeU affects Stable Diffusion XL results:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,
).to("cuda")
prompt = "A squirrel eating a burger"
seed = 2023
# Comes from
# https://wandb.ai/nasirk24/UNET-FreeU-SDXL/reports/FreeU-SDXL-Optimal-Parameters--Vmlldzo1NDg4NTUw
pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
```
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/freeu/sdxl_freeu.jpg)
## Text-to-video generation
FreeU can also be used to improve video quality:
```python
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video
import torch
model_id = "cerspense/zeroscope_v2_576w"
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16).to("cuda")
pipe = pipe.to("cuda")
prompt = "an astronaut riding a horse on mars"
seed = 2023
# The values come from
# https://github.com/lyn-rgb/FreeU_Diffusers#video-pipelines
pipe.enable_freeu(b1=1.2, b2=1.4, s1=0.9, s2=0.2)
video_frames = pipe(prompt, height=320, width=576, num_frames=30, generator=torch.manual_seed(seed)).frames
export_to_video(video_frames, "astronaut_rides_horse.mp4")
```
Thanks to [kadirnar](https://github.com/kadirnar/) for helping to integrate the feature, and to [justindujardin](https://github.com/justindujardin) for the helpful discussions.

View File

@@ -19,6 +19,7 @@ import torch.nn.functional as F
from torch import nn
from ..utils import is_torch_version, logging
from ..utils.torch_utils import apply_freeu
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -249,6 +250,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
resolution_idx=None,
transformer_layers_per_block=1,
num_attention_heads=None,
resnet_groups=None,
@@ -281,6 +283,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -295,6 +298,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -314,6 +318,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -337,6 +342,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -362,6 +368,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -377,6 +384,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -390,6 +398,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -402,6 +411,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -415,6 +425,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -430,6 +441,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -441,6 +453,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -1993,6 +2006,7 @@ class AttnUpBlock2D(nn.Module):
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2075,6 +2089,8 @@ class AttnUpBlock2D(nn.Module):
else:
self.upsamplers = None
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
@@ -2103,6 +2119,7 @@ class CrossAttnUpBlock2D(nn.Module):
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
@@ -2181,6 +2198,7 @@ class CrossAttnUpBlock2D(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
@@ -2194,11 +2212,30 @@ class CrossAttnUpBlock2D(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -2252,6 +2289,7 @@ class UpBlock2D(nn.Module):
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2292,12 +2330,33 @@ class UpBlock2D(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -2331,6 +2390,7 @@ class UpDecoderBlock2D(nn.Module):
self,
in_channels: int,
out_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2370,6 +2430,8 @@ class UpDecoderBlock2D(nn.Module):
else:
self.upsamplers = None
self.resolution_idx = resolution_idx
def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
@@ -2386,6 +2448,7 @@ class AttnUpDecoderBlock2D(nn.Module):
self,
in_channels: int,
out_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2449,6 +2512,8 @@ class AttnUpDecoderBlock2D(nn.Module):
else:
self.upsamplers = None
self.resolution_idx = resolution_idx
def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
@@ -2469,6 +2534,7 @@ class AttnSkipUpBlock2D(nn.Module):
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2553,6 +2619,8 @@ class AttnSkipUpBlock2D(nn.Module):
self.skip_norm = None
self.act = None
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
@@ -2589,6 +2657,7 @@ class SkipUpBlock2D(nn.Module):
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2651,6 +2720,8 @@ class SkipUpBlock2D(nn.Module):
self.skip_norm = None
self.act = None
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
@@ -2684,6 +2755,7 @@ class ResnetUpsampleBlock2D(nn.Module):
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2743,6 +2815,7 @@ class ResnetUpsampleBlock2D(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet in self.resnets:
@@ -2784,6 +2857,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2873,6 +2947,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
@@ -2947,6 +3022,7 @@ class KUpBlock2D(nn.Module):
in_channels: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 5,
resnet_eps: float = 1e-5,
@@ -2988,6 +3064,7 @@ class KUpBlock2D(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
@@ -3027,6 +3104,7 @@ class KCrossAttnUpBlock2D(nn.Module):
in_channels: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 4,
resnet_eps: float = 1e-5,
@@ -3104,6 +3182,7 @@ class KCrossAttnUpBlock2D(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,

View File

@@ -542,6 +542,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
@@ -733,6 +734,38 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
setattr(upsample_block, k, None)
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -15,6 +15,7 @@
import torch
from torch import nn
from ..utils.torch_utils import apply_freeu
from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
@@ -87,6 +88,7 @@ def get_up_block(
resnet_eps,
resnet_act_fn,
num_attention_heads,
resolution_idx=None,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
@@ -107,6 +109,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
)
elif up_block_type == "CrossAttnUpBlock3D":
if cross_attention_dim is None:
@@ -128,6 +131,7 @@ def get_up_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
)
raise ValueError(f"{up_block_type} does not exist.")
@@ -496,6 +500,7 @@ class CrossAttnUpBlock3D(nn.Module):
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resolution_idx=None,
):
super().__init__()
resnets = []
@@ -565,6 +570,7 @@ class CrossAttnUpBlock3D(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
@@ -577,6 +583,13 @@ class CrossAttnUpBlock3D(nn.Module):
num_frames=1,
cross_attention_kwargs=None,
):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
# TODO(Patrick, William) - attention mask is not used
for resnet, temp_conv, attn, temp_attn in zip(
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
@@ -584,6 +597,19 @@ class CrossAttnUpBlock3D(nn.Module):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
@@ -621,6 +647,7 @@ class UpBlock3D(nn.Module):
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
resolution_idx=None,
):
super().__init__()
resnets = []
@@ -661,12 +688,32 @@ class UpBlock3D(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)

View File

@@ -255,6 +255,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
cross_attention_dim=cross_attention_dim,
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
resolution_idx=i,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -462,6 +463,40 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
setattr(upsample_block, k, None)
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -547,6 +547,32 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
latents = latents * self.scheduler.init_noise_sigma
return latents
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Alt Diffusion v1, v2, and Alt Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(

View File

@@ -537,6 +537,32 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
latents = latents * self.scheduler.init_noise_sigma
return latents
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(

View File

@@ -560,6 +560,34 @@ class StableDiffusionXLPipeline(
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1=0.9, s2=0.2, b1=1.2, b2=1.4):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(

View File

@@ -472,6 +472,34 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(

View File

@@ -32,6 +32,7 @@ from ...models.embeddings import (
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import is_torch_version, logging
from ...utils.torch_utils import apply_freeu
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -749,6 +750,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
@@ -941,6 +943,38 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
setattr(upsample_block, k, None)
def forward(
self,
sample: torch.FloatTensor,
@@ -1630,6 +1664,7 @@ class UpBlockFlat(nn.Module):
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -1670,12 +1705,33 @@ class UpBlockFlat(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -1712,6 +1768,7 @@ class CrossAttnUpBlockFlat(nn.Module):
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
@@ -1790,6 +1847,7 @@ class CrossAttnUpBlockFlat(nn.Module):
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
@@ -1803,11 +1861,30 @@ class CrossAttnUpBlockFlat(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:

View File

@@ -22,6 +22,7 @@ from .import_utils import is_torch_available, is_torch_version
if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -86,3 +87,61 @@ def is_compiled_module(module):
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
return False
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def fourier_filter(x_in, threshold, scale):
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
This version of the method comes from here:
https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
"""
x = x_in
B, C, H, W = x.shape
# Non-power of 2 images must be float32
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
x = x.to(dtype=torch.float32)
# FFT
x_freq = fftn(x, dim=(-2, -1))
x_freq = fftshift(x_freq, dim=(-2, -1))
B, C, H, W = x_freq.shape
mask = torch.ones((B, C, H, W), device=x.device)
crow, ccol = H // 2, W // 2
mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
x_freq = x_freq * mask
# IFFT
x_freq = ifftshift(x_freq, dim=(-2, -1))
x_filtered = ifftn(x_freq, dim=(-2, -1)).real
return x_filtered.to(dtype=x_in.dtype)
def apply_freeu(
resolution_idx: int, hidden_states: torch.Tensor, res_hidden_states: torch.Tensor, **freeu_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Applies the FreeU mechanism as introduced in https:
//arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
Args:
resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
hidden_states (`torch.Tensor`): Inputs to the underlying block.
res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if resolution_idx == 0:
num_half_channels = hidden_states.shape[1] // 2
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
if resolution_idx == 1:
num_half_channels = hidden_states.shape[1] // 2
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
return hidden_states, res_hidden_states

View File

@@ -565,6 +565,47 @@ class StableDiffusionPipelineFastTests(
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
def test_freeu_enabled(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "hey"
output = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
assert not np.allclose(
output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
), "Enabling of FreeU should lead to different results."
def test_freeu_disabled(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "hey"
output = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
sd_pipe.disable_freeu()
freeu_keys = {"s1", "s2", "b1", "b2"}
for upsample_block in sd_pipe.unet.up_blocks:
for key in freeu_keys:
assert getattr(upsample_block, key) is None, f"Disabling of FreeU should have set {key} to None."
output_no_freeu = sd_pipe(
prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)
).images
assert np.allclose(
output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
), "Disabling of FreeU should lead to results similar to the default pipeline results."
@slow
@require_torch_gpu
@@ -600,6 +641,20 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
expected_slice = np.array([0.43625, 0.43554, 0.36670, 0.40660, 0.39703, 0.38658, 0.43936, 0.43557, 0.40592])
assert np.abs(image_slice - expected_slice).max() < 3e-3
def test_stable_diffusion_v1_4_with_freeu(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 25
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
image = sd_pipe(**inputs).images
image = image[0, -3:, -3:, -1].flatten()
expected_image = [0.0721, 0.0588, 0.0268, 0.0384, 0.0636, 0.0, 0.0429, 0.0344, 0.0309]
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_1_4_pndm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
sd_pipe = sd_pipe.to(torch_device)
@@ -1079,7 +1134,7 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase):
"generator": generator,
"num_inference_steps": 50,
"guidance_scale": 7.5,
"output_type": "numpy",
"output_type": "np",
}
return inputs
@@ -1155,19 +1210,3 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase):
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_dpm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 25
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_text2img/stable_diffusion_1_4_dpm_multi.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3

View File

@@ -193,3 +193,21 @@ class TextToVideoSDPipelineSlowTests(unittest.TestCase):
video = video_frames.cpu().numpy()
assert np.abs(expected_video - video).mean() < 5e-2
def test_two_step_model_with_freeu(self):
expected_video = []
pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
pipe = pipe.to(torch_device)
prompt = "Spiderman is surfing"
generator = torch.Generator(device="cpu").manual_seed(0)
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
video = video_frames.cpu().numpy()
video = video[0, 0, -3:, -3:, -1].flatten()
expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177]
assert np.abs(expected_video - video).mean() < 5e-2