From 84b82a6cb7045de687684c23caa7683d43a7cd9d Mon Sep 17 00:00:00 2001 From: Kadir Nar Date: Thu, 5 Oct 2023 11:37:04 +0300 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20[Core]=20Add=20FreeU=20mechanism=20?= =?UTF-8?q?(#5164)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ Added Fourier filter function to upsample blocks * 🔧 Update Fourier_filter for float16 support * ✨ Added UNetFreeUConfig to UNet model for FreeU adaptation 🛠️ * move unet to its original form and add fourier_filter to torch_utils. * implement freeU enable mechanism * implement disable mechanism * resolution index. * correct resolution idx condition. * fix copies. * no need to use resolution_idx in vae. * spell out the kwargs * proper config property * fix attribution setting * place unet hasattr properly. * fix: attribute access. * proper disable * remove validation method. * debug * debug * debug * debug * debug * debug * potential fix. * add: doc. * fix copies * add: tests. * add: support freeU in SDXL. * set default value of resolution idx. * set default values for resolution_idx. * fix copies * fix rest. * fix copies * address PR comments. * run fix-copies * move apply_free_u to utils and other minors. * introduce support for video (unet3D) * minor ups * consistent fix-copies. * consistent stuff * fix-copies * add: rest * add: docs. * fix: tests * fix: doc path * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * style up * move to techniques. * add: slow test for sd freeu. * add: slow test for sd freeu. * add: slow test for sd freeu. * add: slow test for sd freeu. * add: slow test for sd freeu. * add: slow test for sd freeu. 
* add: slow test for video with freeu * add: slow test for video with freeu * add: slow test for video with freeu * style --------- Co-authored-by: Sayak Paul Co-authored-by: Patrick von Platen Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/using-diffusers/freeu.md | 123 ++++++++++++++++++ src/diffusers/models/unet_2d_blocks.py | 79 +++++++++++ src/diffusers/models/unet_2d_condition.py | 33 +++++ src/diffusers/models/unet_3d_blocks.py | 47 +++++++ src/diffusers/models/unet_3d_condition.py | 35 +++++ .../alt_diffusion/pipeline_alt_diffusion.py | 26 ++++ .../pipeline_stable_diffusion.py | 26 ++++ .../pipeline_stable_diffusion_xl.py | 28 ++++ .../pipeline_text_to_video_synth.py | 28 ++++ .../versatile_diffusion/modeling_text_unet.py | 77 +++++++++++ src/diffusers/utils/torch_utils.py | 59 +++++++++ .../stable_diffusion/test_stable_diffusion.py | 73 ++++++++--- .../test_text_to_video.py | 18 +++ 14 files changed, 637 insertions(+), 17 deletions(-) create mode 100644 docs/source/en/using-diffusers/freeu.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index cc50a95643..d95e553bd3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -58,6 +58,8 @@ title: Control image brightness - local: using-diffusers/weighted_prompts title: Prompt weighting + - local: using-diffusers/freeu + title: Improve generation quality with FreeU title: Techniques - sections: - local: using-diffusers/pipeline_overview diff --git a/docs/source/en/using-diffusers/freeu.md b/docs/source/en/using-diffusers/freeu.md new file mode 100644 index 0000000000..6c23ec7543 --- /dev/null +++ b/docs/source/en/using-diffusers/freeu.md @@ -0,0 +1,123 @@ +# Improve generation quality with FreeU + +[[open-in-colab]] + +The UNet is responsible for denoising during the reverse diffusion process, and there are two distinct features in its architecture: + +1. 
Backbone features primarily contribute to the denoising process +2. Skip features mainly introduce high-frequency features into the decoder module and can make the network overlook the semantics in the backbone features + +However, the skip connection can sometimes introduce unnatural image details. [FreeU](https://hf.co/papers/2309.11497) is a technique for improving image quality by rebalancing the contributions from the UNet’s skip connections and backbone feature maps. + +FreeU is applied during inference and it does not require any additional training. The technique works for different tasks such as text-to-image, image-to-image, and text-to-video. + +In this guide, you will apply FreeU to the [`StableDiffusionPipeline`], [`StableDiffusionXLPipeline`], and [`TextToVideoSDPipeline`]. + +## StableDiffusionPipeline + +Load the pipeline: + +```py +from diffusers import DiffusionPipeline +import torch + +pipeline = DiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None +).to("cuda") +``` + +Then enable the FreeU mechanism with the FreeU-specific hyperparameters. These values are scaling factors for the backbone and skip features. + +```py +pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) +``` + +The values above are from the official FreeU [code repository](https://github.com/ChenyangSi/FreeU) where you can also find [reference hyperparameters](https://github.com/ChenyangSi/FreeU#range-for-more-parameters) for different models. + + + +Disable the FreeU mechanism by calling `disable_freeu()` on a pipeline. 
+ + + +And then run inference: + +```py +prompt = "A squirrel eating a burger" +seed = 2023 +image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0] +``` + +The figure below compares non-FreeU and FreeU results respectively for the same hyperparameters used above (`prompt` and `seed`): + +![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/freeu/sdv1_5_freeu.jpg) + + +Let's see how Stable Diffusion 2 results are impacted: + +```py +from diffusers import DiffusionPipeline +import torch + +pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, safety_checker=None +).to("cuda") + +prompt = "A squirrel eating a burger" +seed = 2023 + +pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2) +image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0] +``` + + +![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/freeu/sdv2_1_freeu.jpg) + +## Stable Diffusion XL + +Finally, let's take a look at how FreeU affects Stable Diffusion XL results: + +```py +from diffusers import DiffusionPipeline +import torch + +pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, +).to("cuda") + +prompt = "A squirrel eating a burger" +seed = 2023 + +# Comes from +# https://wandb.ai/nasirk24/UNET-FreeU-SDXL/reports/FreeU-SDXL-Optimal-Parameters--Vmlldzo1NDg4NTUw +pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2) +image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0] +``` + + +![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/freeu/sdxl_freeu.jpg) + +## Text-to-video generation + +FreeU can also be used to improve video quality: + +```python +from diffusers import DiffusionPipeline +from diffusers.utils import export_to_video +import torch + +model_id = "cerspense/zeroscope_v2_576w" +pipe = 
DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) +pipe = pipe.to("cuda") + +prompt = "an astronaut riding a horse on mars" +seed = 2023 + +# The values come from +# https://github.com/lyn-rgb/FreeU_Diffusers#video-pipelines +pipe.enable_freeu(b1=1.2, b2=1.4, s1=0.9, s2=0.2) +video_frames = pipe(prompt, height=320, width=576, num_frames=30, generator=torch.manual_seed(seed)).frames +export_to_video(video_frames, "astronaut_rides_horse.mp4") +``` + +Thanks to [kadirnar](https://github.com/kadirnar/) for helping to integrate the feature, and to [justindujardin](https://github.com/justindujardin) for the helpful discussions. \ No newline at end of file diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 8aebb3aad6..1290eff63e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from torch import nn from ..utils import is_torch_version, logging +from ..utils.torch_utils import apply_freeu from .activations import get_activation from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 @@ -249,6 +250,7 @@ def get_up_block( add_upsample, resnet_eps, resnet_act_fn, + resolution_idx=None, transformer_layers_per_block=1, num_attention_heads=None, resnet_groups=None, @@ -281,6 +283,7 @@ def get_up_block( out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -295,6 +298,7 @@ def get_up_block( out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -314,6 +318,7 @@ def get_up_block( out_channels=out_channels, 
prev_output_channel=prev_output_channel, temb_channels=temb_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -337,6 +342,7 @@ def get_up_block( out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -362,6 +368,7 @@ def get_up_block( out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, + resolution_idx=resolution_idx, dropout=dropout, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, @@ -377,6 +384,7 @@ def get_up_block( out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -390,6 +398,7 @@ def get_up_block( out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -402,6 +411,7 @@ def get_up_block( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -415,6 +425,7 @@ def get_up_block( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -430,6 +441,7 @@ def get_up_block( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -441,6 +453,7 @@ def get_up_block( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, + resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -1993,6 +2006,7 @@ 
class AttnUpBlock2D(nn.Module): prev_output_channel: int, out_channels: int, temb_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2075,6 +2089,8 @@ class AttnUpBlock2D(nn.Module): else: self.upsamplers = None + self.resolution_idx = resolution_idx + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -2103,6 +2119,7 @@ class CrossAttnUpBlock2D(nn.Module): out_channels: int, prev_output_channel: int, temb_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, @@ -2181,6 +2198,7 @@ class CrossAttnUpBlock2D(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward( self, @@ -2194,11 +2212,30 @@ class CrossAttnUpBlock2D(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, ): lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -2252,6 +2289,7 @@ class UpBlock2D(nn.Module): prev_output_channel: int, out_channels: int, temb_channels: int, + resolution_idx: int = None, dropout: 
float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2292,12 +2330,33 @@ class UpBlock2D(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -2331,6 +2390,7 @@ class UpDecoderBlock2D(nn.Module): self, in_channels: int, out_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2370,6 +2430,8 @@ class UpDecoderBlock2D(nn.Module): else: self.upsamplers = None + self.resolution_idx = resolution_idx + def forward(self, hidden_states, temb=None, scale: float = 1.0): for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=temb, scale=scale) @@ -2386,6 +2448,7 @@ class AttnUpDecoderBlock2D(nn.Module): self, in_channels: int, out_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2449,6 +2512,8 @@ class AttnUpDecoderBlock2D(nn.Module): else: self.upsamplers = None + self.resolution_idx = resolution_idx + def forward(self, hidden_states, temb=None, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=temb, scale=scale) @@ 
-2469,6 +2534,7 @@ class AttnSkipUpBlock2D(nn.Module): prev_output_channel: int, out_channels: int, temb_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2553,6 +2619,8 @@ class AttnSkipUpBlock2D(nn.Module): self.skip_norm = None self.act = None + self.resolution_idx = resolution_idx + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states @@ -2589,6 +2657,7 @@ class SkipUpBlock2D(nn.Module): prev_output_channel: int, out_channels: int, temb_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2651,6 +2720,8 @@ class SkipUpBlock2D(nn.Module): self.skip_norm = None self.act = None + self.resolution_idx = resolution_idx + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states @@ -2684,6 +2755,7 @@ class ResnetUpsampleBlock2D(nn.Module): prev_output_channel: int, out_channels: int, temb_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2743,6 +2815,7 @@ class ResnetUpsampleBlock2D(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet in self.resnets: @@ -2784,6 +2857,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): out_channels: int, prev_output_channel: int, temb_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2873,6 +2947,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward( self, @@ -2947,6 +3022,7 @@ class 
KUpBlock2D(nn.Module): in_channels: int, out_channels: int, temb_channels: int, + resolution_idx: int, dropout: float = 0.0, num_layers: int = 5, resnet_eps: float = 1e-5, @@ -2988,6 +3064,7 @@ class KUpBlock2D(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): res_hidden_states_tuple = res_hidden_states_tuple[-1] @@ -3027,6 +3104,7 @@ class KCrossAttnUpBlock2D(nn.Module): in_channels: int, out_channels: int, temb_channels: int, + resolution_idx: int, dropout: float = 0.0, num_layers: int = 4, resnet_eps: float = 1e-5, @@ -3104,6 +3182,7 @@ class KCrossAttnUpBlock2D(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward( self, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 866254a895..52c3fc141e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -542,6 +542,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resolution_idx=i, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], @@ -733,6 +734,38 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. 
+ + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index ab5c393518..180ae0dc1a 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -15,6 +15,7 @@ import torch from torch import nn +from ..utils.torch_utils import apply_freeu from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D from .transformer_2d import Transformer2DModel from .transformer_temporal import TransformerTemporalModel @@ -87,6 +88,7 @@ def get_up_block( resnet_eps, resnet_act_fn, num_attention_heads, + resolution_idx=None, resnet_groups=None, cross_attention_dim=None, 
dual_cross_attention=False, @@ -107,6 +109,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, ) elif up_block_type == "CrossAttnUpBlock3D": if cross_attention_dim is None: @@ -128,6 +131,7 @@ def get_up_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, ) raise ValueError(f"{up_block_type} does not exist.") @@ -496,6 +500,7 @@ class CrossAttnUpBlock3D(nn.Module): use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + resolution_idx=None, ): super().__init__() resnets = [] @@ -565,6 +570,7 @@ class CrossAttnUpBlock3D(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward( self, @@ -577,6 +583,13 @@ class CrossAttnUpBlock3D(nn.Module): num_frames=1, cross_attention_kwargs=None, ): + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + # TODO(Patrick, William) - attention mask is not used for resnet, temp_conv, attn, temp_attn in zip( self.resnets, self.temp_convs, self.attentions, self.temp_attentions @@ -584,6 +597,19 @@ class CrossAttnUpBlock3D(nn.Module): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) @@ -621,6 +647,7 @@ class UpBlock3D(nn.Module): resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, 
+ resolution_idx=None, ): super().__init__() resnets = [] @@ -661,12 +688,32 @@ class UpBlock3D(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) for resnet, temp_conv in zip(self.resnets, self.temp_convs): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 01af31061d..4e6de97390 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -255,6 +255,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) cross_attention_dim=cross_attention_dim, num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=False, + resolution_idx=i, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -462,6 +463,40 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. 
+ + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 66c0fc4891..ba3930f5da 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -547,6 +547,32 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL latents = latents * self.scheduler.init_noise_sigma 
return latents + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 0050840dd1..68cdbbe78b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -537,6 +537,32 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo latents = latents * self.scheduler.init_noise_sigma return latents + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 98eeb8e344..4c1bd857d7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -560,6 +560,34 @@ class StableDiffusionXLPipeline( self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1=0.9, s2=0.2, b1=1.2, b2=1.4): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index fa2c4a28d1..42c00597be 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -472,6 +472,34 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. 
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index f2b191496a..7d1ca1d934 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -32,6 +32,7 @@ from ...models.embeddings import ( from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput from ...utils import is_torch_version, logging +from ...utils.torch_utils import apply_freeu logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -749,6 +750,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resolution_idx=i, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], @@ -941,6 +943,38 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. 
+ + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables FreeU by resetting `s1`, `s2`, `b1`, `b2` to `None`; safe to call if it was never enabled.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + def forward( self, sample: torch.FloatTensor, @@ -1630,6 +1664,7 @@ class UpBlockFlat(nn.Module): prev_output_channel: int, out_channels: int, temb_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -1670,12 +1705,33 @@ class UpBlockFlat(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + for resnet in self.resnets:
# pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1712,6 +1768,7 @@ class CrossAttnUpBlockFlat(nn.Module): out_channels: int, prev_output_channel: int, temb_channels: int, + resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, @@ -1790,6 +1847,7 @@ class CrossAttnUpBlockFlat(nn.Module): self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx def forward( self, @@ -1803,11 +1861,30 @@ class CrossAttnUpBlockFlat(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, ): lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 99ea4d8cf1..7955ccb01d 100644 --- a/src/diffusers/utils/torch_utils.py +++ 
b/src/diffusers/utils/torch_utils.py @@ -22,6 +22,7 @@ from .import_utils import is_torch_available, is_torch_version if is_torch_available(): import torch + from torch.fft import fftn, fftshift, ifftn, ifftshift logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -86,3 +87,61 @@ def is_compiled_module(module): if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): return False return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) + + +def fourier_filter(x_in, threshold, scale): + """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). + + This version of the method comes from here: + https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706 + """ + x = x_in + B, C, H, W = x.shape + + # Non-power of 2 images must be float32 + if (W & (W - 1)) != 0 or (H & (H - 1)) != 0: + x = x.to(dtype=torch.float32) + + # FFT + x_freq = fftn(x, dim=(-2, -1)) + x_freq = fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + + crow, ccol = H // 2, W // 2 + mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = ifftshift(x_freq, dim=(-2, -1)) + x_filtered = ifftn(x_freq, dim=(-2, -1)).real + + return x_filtered.to(dtype=x_in.dtype) + + +def apply_freeu( + resolution_idx: int, hidden_states: torch.Tensor, res_hidden_states: torch.Tensor, **freeu_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: + """Applies the FreeU mechanism as introduced in https://arxiv.org/abs/2309.11497. + Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. + + Args: + resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied. + hidden_states (`torch.Tensor`): Inputs to the underlying block. + res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
+ s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. + s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if resolution_idx == 0: + num_half_channels = hidden_states.shape[1] // 2 + hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"] + res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"]) + if resolution_idx == 1: + num_half_channels = hidden_states.shape[1] // 2 + hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"] + res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"]) + + return hidden_states, res_hidden_states diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index df9e8d47f1..d6a63b9891 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -565,6 +565,47 @@ class StableDiffusionPipelineFastTests( def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) + def test_freeu_enabled(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "hey" + output = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images + + sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) + output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images + + assert not np.allclose( + output[0, -3:, 
-3:, -1], output_freeu[0, -3:, -3:, -1] + ), "Enabling of FreeU should lead to different results." + + def test_freeu_disabled(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "hey" + output = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images + + sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) + sd_pipe.disable_freeu() + + freeu_keys = {"s1", "s2", "b1", "b2"} + for upsample_block in sd_pipe.unet.up_blocks: + for key in freeu_keys: + assert getattr(upsample_block, key) is None, f"Disabling of FreeU should have set {key} to None." + + output_no_freeu = sd_pipe( + prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0) + ).images + + assert np.allclose( + output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1] + ), "Disabling of FreeU should lead to results similar to the default pipeline results." 
+ @slow @require_torch_gpu @@ -600,6 +641,20 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase): expected_slice = np.array([0.43625, 0.43554, 0.36670, 0.40660, 0.39703, 0.38658, 0.43936, 0.43557, 0.40592]) assert np.abs(image_slice - expected_slice).max() < 3e-3 + def test_stable_diffusion_v1_4_with_freeu(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + inputs["num_inference_steps"] = 25 + + sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) + image = sd_pipe(**inputs).images + image = image[0, -3:, -3:, -1].flatten() + expected_image = [0.0721, 0.0588, 0.0268, 0.0384, 0.0636, 0.0, 0.0429, 0.0344, 0.0309] + max_diff = np.abs(expected_image - image).max() + assert max_diff < 1e-3 + def test_stable_diffusion_1_4_pndm(self): sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") sd_pipe = sd_pipe.to(torch_device) @@ -1079,7 +1134,7 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase): "generator": generator, "num_inference_steps": 50, "guidance_scale": 7.5, - "output_type": "numpy", + "output_type": "np", } return inputs @@ -1155,19 +1210,3 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase): ) max_diff = np.abs(expected_image - image).max() assert max_diff < 1e-3 - - def test_stable_diffusion_dpm(self): - sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device) - sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) - sd_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_inputs(torch_device) - inputs["num_inference_steps"] = 25 - image = sd_pipe(**inputs).images[0] - - expected_image = load_numpy( - "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" - "/stable_diffusion_text2img/stable_diffusion_1_4_dpm_multi.npy" - ) - max_diff = 
np.abs(expected_image - image).max() - assert max_diff < 1e-3 diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py index 2c47dc492d..933583ce4b 100644 --- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py @@ -193,3 +193,21 @@ class TextToVideoSDPipelineSlowTests(unittest.TestCase): video = video_frames.cpu().numpy() assert np.abs(expected_video - video).mean() < 5e-2 + + def test_two_step_model_with_freeu(self): + expected_video = [] + + pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") + pipe = pipe.to(torch_device) + + prompt = "Spiderman is surfing" + generator = torch.Generator(device="cpu").manual_seed(0) + + pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) + video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames + video = video_frames.cpu().numpy() + video = video[0, 0, -3:, -3:, -1].flatten() + + expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177] + + assert np.abs(expected_video - video).mean() < 5e-2