Modular Diffusers Guiders (#11311)
* cfg; slg; pag; sdxl without controlnet
* support sdxl controlnet
* support controlnet union
* update
* update
* cfg zero*
* use unwrap_module for torch compiled modules
* remove guider kwargs
* remove commented code
* remove old guider
* fix slg bug
* remove debug print
* autoguidance
* smoothed energy guidance
* add note about seg
* tangential cfg
* cfg plus plus
* support cfgpp in ddim
* apply review suggestions
* refactor
* rename enable/disable
* remove cfg++ for now
* rename do_classifier_free_guidance->prepare_unconditional_embeds
* remove unused
src/diffusers/__init__.py

@@ -33,6 +33,7 @@ from .utils import (
_import_structure = {
    "configuration_utils": ["ConfigMixin"],
    "guiders": [],
    "hooks": [],
    "loaders": ["FromOriginalModelMixin"],
    "models": [],
@@ -129,12 +130,26 @@ except OptionalDependencyNotAvailable:
    _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]

else:
    _import_structure["guiders"].extend(
        [
            "AdaptiveProjectedGuidance",
            "AutoGuidance",
            "ClassifierFreeGuidance",
            "ClassifierFreeZeroStarGuidance",
            "SkipLayerGuidance",
            "SmoothedEnergyGuidance",
            "TangentialClassifierFreeGuidance",
        ]
    )
    _import_structure["hooks"].extend(
        [
            "FasterCacheConfig",
            "HookRegistry",
            "PyramidAttentionBroadcastConfig",
            "LayerSkipConfig",
            "SmoothedEnergyGuidanceConfig",
            "apply_faster_cache",
            "apply_layer_skip",
            "apply_pyramid_attention_broadcast",
        ]
    )
@@ -711,10 +726,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from .utils.dummy_pt_objects import *  # noqa F403
    else:
        from .guiders import (
            AdaptiveProjectedGuidance,
            AutoGuidance,
            ClassifierFreeGuidance,
            ClassifierFreeZeroStarGuidance,
            SkipLayerGuidance,
            SmoothedEnergyGuidance,
            TangentialClassifierFreeGuidance,
        )
        from .hooks import (
            FasterCacheConfig,
            HookRegistry,
            LayerSkipConfig,
            PyramidAttentionBroadcastConfig,
            SmoothedEnergyGuidanceConfig,
            apply_layer_skip,
            apply_faster_cache,
            apply_pyramid_attention_broadcast,
        )
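Once registered in `_import_structure`, these names resolve lazily on first attribute access at the top level. A minimal sketch, assuming a diffusers build that includes this commit:

```python
import diffusers

# triggers the lazy import machinery for the "guiders" module
print(diffusers.ClassifierFreeGuidance)
print(diffusers.SmoothedEnergyGuidance)
```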
@@ -1,748 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from .models.attention_processor import (
    Attention,
    AttentionProcessor,
    PAGCFGIdentitySelfAttnProcessor2_0,
    PAGIdentitySelfAttnProcessor2_0,
)
from .utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
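The rescale is simply a per-sample standard-deviation match between the guided prediction and the text-conditioned prediction, blended by `guidance_rescale`. A minimal sketch with dummy tensors (shapes are illustrative, not tied to any real pipeline):

```python
import torch

# hypothetical stand-ins for a denoiser's conditional / guided noise predictions
noise_pred_text = torch.randn(2, 4, 64, 64)
noise_cfg = noise_pred_text * 3.0  # over-amplified, as high CFG scales tend to produce

rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)

# with guidance_rescale=1.0 the per-sample std would match noise_pred_text exactly;
# at 0.7 it is a 70/30 blend of the std-matched and the raw guided prediction
print(rescaled.std(dim=(1, 2, 3)), noise_cfg.std(dim=(1, 2, 3)))
```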
class CFGGuider:
    """
    This class is used to guide the pipeline with CFG (Classifier-Free Guidance).
    """

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1.0 and not self._disable_guidance

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def batch_size(self):
        return self._batch_size

    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
        # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead
        disable_guidance = guider_kwargs.get("disable_guidance", False)
        guidance_scale = guider_kwargs.get("guidance_scale", None)
        if guidance_scale is None:
            raise ValueError("guidance_scale is not provided in guider_kwargs")
        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
        batch_size = guider_kwargs.get("batch_size", None)
        if batch_size is None:
            raise ValueError("batch_size is not provided in guider_kwargs")
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._batch_size = batch_size
        self._disable_guidance = disable_guidance

    def reset_guider(self, pipeline):
        pass

    def maybe_update_guider(self, pipeline, timestep):
        pass

    def maybe_update_input(self, pipeline, cond_input):
        pass

    def _maybe_split_prepared_input(self, cond):
        """
        Process and potentially split the conditional input for Classifier-Free Guidance (CFG).

        This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
        It determines whether to split the input based on its batch size relative to the expected batch size.

        Args:
            cond (torch.Tensor): The conditional input tensor to process.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - The negative conditional input (uncond_input)
                - The positive conditional input (cond_input)
        """
        if cond.shape[0] == self.batch_size * 2:
            neg_cond = cond[0 : self.batch_size]
            cond = cond[self.batch_size :]
            return neg_cond, cond
        elif cond.shape[0] == self.batch_size:
            return cond, cond
        else:
            raise ValueError(f"Unsupported input shape: {cond.shape}")

    def _is_prepared_input(self, cond):
        """
        Check if the input is already prepared for Classifier-Free Guidance (CFG).

        Args:
            cond (torch.Tensor): The conditional input tensor to check.

        Returns:
            bool: True if the input is already prepared, False otherwise.
        """
        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond

        return cond_tensor.shape[0] == self.batch_size * 2

    def prepare_input(
        self,
        cond_input: Union[torch.Tensor, List[torch.Tensor]],
        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Prepare the input for CFG.

        Args:
            cond_input (Union[torch.Tensor, List[torch.Tensor]]):
                The conditional input. It can be a single tensor or a list of tensors. It must have the same length
                as `negative_cond_input`.
            negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]):
                The negative conditional input. It can be a single tensor or a list of tensors. It must have the same
                length as `cond_input`.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
        """

        # we check if cond_input already has CFG applied, and split if it is the case.
        if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
            return cond_input

        if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
            if isinstance(cond_input, list):
                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
            else:
                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)

        if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
            raise ValueError(
                "`negative_cond_input` is required when cond_input does not already contain negative conditional input"
            )

        if isinstance(cond_input, (list, tuple)):
            if not self.do_classifier_free_guidance:
                return cond_input

            if len(negative_cond_input) != len(cond_input):
                raise ValueError("The length of negative_cond_input and cond_input must be the same.")
            prepared_input = []
            for neg_cond, cond in zip(negative_cond_input, cond_input):
                if neg_cond.shape[0] != cond.shape[0]:
                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
                prepared_input.append(torch.cat([neg_cond, cond], dim=0))
            return prepared_input

        elif isinstance(cond_input, torch.Tensor):
            if not self.do_classifier_free_guidance:
                return cond_input
            else:
                return torch.cat([negative_cond_input, cond_input], dim=0)

        else:
            raise ValueError(f"Unsupported input type: {type(cond_input)}")

    def apply_guidance(
        self,
        model_output: torch.Tensor,
        timestep: int = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if not self.do_classifier_free_guidance:
            return model_output

        noise_pred_uncond, noise_pred_text = model_output.chunk(2)
        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        if self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
        return noise_pred
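A minimal sketch of the old guider protocol being removed here (the `pipeline` argument is unused by `CFGGuider.set_guider`, so `None` stands in; shapes are illustrative):

```python
import torch

guider = CFGGuider()
guider.set_guider(None, {"guidance_scale": 7.5, "batch_size": 1})

prompt_embeds = torch.randn(1, 77, 768)           # conditional embeddings
negative_prompt_embeds = torch.zeros(1, 77, 768)  # unconditional embeddings

# batch is doubled to [uncond, cond] for a single denoiser forward pass
model_input = guider.prepare_input(prompt_embeds, negative_prompt_embeds)
assert model_input.shape[0] == 2

# pretend this came out of the denoiser for the doubled batch
model_output = torch.randn(2, 4, 64, 64)
noise_pred = guider.apply_guidance(model_output)  # recombined, batch size 1
assert noise_pred.shape[0] == 1
```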
class PAGGuider:
    """
    This class is used to guide the pipeline with PAG (Perturbed Attention Guidance).
    """

    def __init__(
        self,
        pag_applied_layers: Union[str, List[str]],
        pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
            PAGCFGIdentitySelfAttnProcessor2_0(),
            PAGIdentitySelfAttnProcessor2_0(),
        ),
    ):
        r"""
        Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.

        Args:
            pag_applied_layers (`str` or `List[str]`):
                One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
                PAG is to be applied. A few ways of expected usage are as follows:
                    - Single layers specified as - "blocks.{layer_index}"
                    - Multiple layers as a list - ["blocks.{layer_index_1}", "blocks.{layer_index_2}", ...]
                    - Multiple layers as a block name - "mid"
                    - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
            pag_attn_processors (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0())`):
                A tuple of two attention processors. The first attention processor is for PAG with Classifier-free
                guidance enabled (conditional and unconditional). The second attention processor is for PAG with CFG
                disabled (unconditional only).
        """

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]
        if pag_attn_processors is not None:
            if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
                raise ValueError("Expected a tuple of two attention processors")

        for i in range(len(pag_applied_layers)):
            if not isinstance(pag_applied_layers[i], str):
                raise ValueError(
                    f"Expected either a string or a list of strings but got type {type(pag_applied_layers[i])}"
                )

        self.pag_applied_layers = pag_applied_layers
        self._pag_attn_processors = pag_attn_processors

    def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.
        """
        pag_attn_processors = self._pag_attn_processors
        pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]

        def is_self_attn(module: nn.Module) -> bool:
            r"""
            Check if the module is a self-attention module based on its name.
            """
            return isinstance(module, Attention) and not module.is_cross_attention

        def is_fake_integral_match(layer_id, name):
            layer_id = layer_id.split(".")[-1]
            name = name.split(".")[-1]
            return layer_id.isnumeric() and name.isnumeric() and layer_id == name

        for layer_id in pag_applied_layers:
            # for each PAG layer input, we find corresponding self-attention layers in the unet model
            target_modules = []

            for name, module in model.named_modules():
                # Identify the following simple cases:
                #   (1) Self Attention layer existing
                #   (2) Whether the module name matches pag layer id even partially
                #   (3) Make sure it's not a fake integral match if the layer_id ends with a number
                #       For example, blocks.1 and blocks.10 should be distinguishable if layer_id="blocks.1"
                if (
                    is_self_attn(module)
                    and re.search(layer_id, name) is not None
                    and not is_fake_integral_match(layer_id, name)
                ):
                    logger.debug(f"Applying PAG to layer: {name}")
                    target_modules.append(module)

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")

            for module in target_modules:
                module.processor = pag_attn_proc

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and not self._disable_guidance

    @property
    def do_perturbed_attention_guidance(self):
        return self._pag_scale > 0 and not self._disable_guidance

    @property
    def do_pag_adaptive_scaling(self):
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def pag_scale(self):
        return self._pag_scale

    @property
    def pag_adaptive_scale(self):
        return self._pag_adaptive_scale

    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
        pag_scale = guider_kwargs.get("pag_scale", 3.0)
        pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0)

        batch_size = guider_kwargs.get("batch_size", None)
        if batch_size is None:
            raise ValueError("batch_size is a required argument for PAGGuider")

        guidance_scale = guider_kwargs.get("guidance_scale", None)
        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
        disable_guidance = guider_kwargs.get("disable_guidance", False)

        if guidance_scale is None:
            raise ValueError("guidance_scale is a required argument for PAGGuider")

        self._pag_scale = pag_scale
        self._pag_adaptive_scale = pag_adaptive_scale
        self._guidance_scale = guidance_scale
        self._disable_guidance = disable_guidance
        self._guidance_rescale = guidance_rescale
        self._batch_size = batch_size
        if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None:
            pipeline.original_attn_proc = pipeline.unet.attn_processors
        self._set_pag_attn_processor(
            model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer,
            pag_applied_layers=self.pag_applied_layers,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )

    def reset_guider(self, pipeline):
        if (
            self.do_perturbed_attention_guidance
            and hasattr(pipeline, "original_attn_proc")
            and pipeline.original_attn_proc is not None
        ):
            pipeline.unet.set_attn_processor(pipeline.original_attn_proc)
            pipeline.original_attn_proc = None

    def maybe_update_guider(self, pipeline, timestep):
        pass

    def maybe_update_input(self, pipeline, cond_input):
        pass

    def _is_prepared_input(self, cond):
        """
        Check if the input is already prepared for Perturbed Attention Guidance (PAG).

        Args:
            cond (torch.Tensor): The conditional input tensor to check.

        Returns:
            bool: True if the input is already prepared, False otherwise.
        """
        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond

        return cond_tensor.shape[0] == self.batch_size * 3

    def _maybe_split_prepared_input(self, cond):
        """
        Process and potentially split the conditional input for Perturbed Attention Guidance (PAG).

        This method handles inputs that may already have PAG applied (i.e. when `cond` is output of `prepare_input`).
        It determines whether to split the input based on its batch size relative to the expected batch size.

        Args:
            cond (torch.Tensor): The conditional input tensor to process.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - The negative conditional input (uncond_input)
                - The positive conditional input (cond_input)
        """
        if cond.shape[0] == self.batch_size * 3:
            neg_cond = cond[0 : self.batch_size]
            cond = cond[self.batch_size : self.batch_size * 2]
            return neg_cond, cond
        elif cond.shape[0] == self.batch_size:
            return cond, cond
        else:
            raise ValueError(f"Unsupported input shape: {cond.shape}")

    def prepare_input(
        self,
        cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]:
        """
        Prepare the input for PAG.

        Args:
            cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
                The conditional input. It can be a single tensor or a list of tensors. It must have the same length
                as `negative_cond_input`.
            negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
                The negative conditional input. It can be a single tensor or a list of tensors. It must have the same
                length as `cond_input`.

        Returns:
            Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input.
        """

        # we check if cond_input already has PAG applied, and split if it is the case.

        if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance:
            return cond_input

        if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance:
            if isinstance(cond_input, list):
                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
            else:
                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)

        if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None:
            raise ValueError(
                "`negative_cond_input` is required when cond_input does not already contain negative conditional input"
            )

        if isinstance(cond_input, (list, tuple)):
            if not self.do_perturbed_attention_guidance:
                return cond_input

            if len(negative_cond_input) != len(cond_input):
                raise ValueError("The length of negative_cond_input and cond_input must be the same.")

            prepared_input = []
            for neg_cond, cond in zip(negative_cond_input, cond_input):
                if neg_cond.shape[0] != cond.shape[0]:
                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")

                cond = torch.cat([cond] * 2, dim=0)
                if self.do_classifier_free_guidance:
                    prepared_input.append(torch.cat([neg_cond, cond], dim=0))
                else:
                    prepared_input.append(cond)

            return prepared_input

        elif isinstance(cond_input, torch.Tensor):
            if not self.do_perturbed_attention_guidance:
                return cond_input

            cond_input = torch.cat([cond_input] * 2, dim=0)
            if self.do_classifier_free_guidance:
                return torch.cat([negative_cond_input, cond_input], dim=0)
            else:
                return cond_input

        else:
            raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}")

    def apply_guidance(
        self,
        model_output: torch.Tensor,
        timestep: int,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if not self.do_perturbed_attention_guidance:
            return model_output

        if self.do_pag_adaptive_scaling:
            pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0)
        else:
            pag_scale = self._pag_scale

        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = model_output.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)

        if self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

        return noise_pred
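With both CFG and PAG active, the effective batch is tripled to [uncond, cond, cond-with-perturbed-attention], and `apply_guidance` recombines the three chunks. A small sketch of just that arithmetic, with dummy tensors and no real denoiser:

```python
import torch

guidance_scale, pag_scale = 5.0, 3.0
uncond, cond, perturb = torch.randn(3, 1, 4, 64, 64)

noise_pred = (
    uncond
    + guidance_scale * (cond - uncond)  # classifier-free guidance term
    + pag_scale * (cond - perturb)      # perturbed-attention guidance term
)
```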
class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
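Note the buffer is a leaky accumulator rather than a conventional EMA: with a negative `momentum` (as the APG paper suggests) it damps the previous history instead of averaging toward it. A quick numeric check:

```python
import torch

buf = MomentumBuffer(momentum=-0.5)
for v in (1.0, 1.0, 1.0):
    buf.update(torch.tensor(v))
# running_average: 1.0, then 1 - 0.5*1 = 0.5, then 1 - 0.5*0.5 = 0.75
print(buf.running_average)  # tensor(0.7500)
```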
class APGGuider:
    """
    This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
    """

    def normalized_guidance(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: torch.Tensor,
        guidance_scale: float,
        momentum_buffer: MomentumBuffer = None,
        norm_threshold: float = 0.0,
        eta: float = 1.0,
    ):
        """
        Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion
        Models](https://arxiv.org/pdf/2410.02416)
        """
        diff = pred_cond - pred_uncond
        if momentum_buffer is not None:
            momentum_buffer.update(diff)
            diff = momentum_buffer.running_average
        if norm_threshold > 0:
            ones = torch.ones_like(diff)
            diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
            diff = diff * scale_factor
        v0, v1 = diff.double(), pred_cond.double()
        v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
        v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
        v0_orthogonal = v0 - v0_parallel
        diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
        normalized_update = diff_orthogonal + eta * diff_parallel
        pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
        return pred_guided

    @property
    def adaptive_projected_guidance_momentum(self):
        return self._adaptive_projected_guidance_momentum

    @property
    def adaptive_projected_guidance_rescale_factor(self):
        return self._adaptive_projected_guidance_rescale_factor

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1.0 and not self._disable_guidance

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def batch_size(self):
        return self._batch_size

    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
        disable_guidance = guider_kwargs.get("disable_guidance", False)
        guidance_scale = guider_kwargs.get("guidance_scale", None)
        if guidance_scale is None:
            raise ValueError("guidance_scale is not provided in guider_kwargs")
        adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
        adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
            "adaptive_projected_guidance_rescale_factor", 15.0
        )
        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
        batch_size = guider_kwargs.get("batch_size", None)
        if batch_size is None:
            raise ValueError("batch_size is not provided in guider_kwargs")
        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._batch_size = batch_size
        self._disable_guidance = disable_guidance
        if adaptive_projected_guidance_momentum is not None:
            self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
        else:
            self.momentum_buffer = None
        self.scheduler = pipeline.scheduler

    def reset_guider(self, pipeline):
        pass

    def maybe_update_guider(self, pipeline, timestep):
        pass

    def maybe_update_input(self, pipeline, cond_input):
        pass

    def _maybe_split_prepared_input(self, cond):
        """
        Process and potentially split the conditional input for Classifier-Free Guidance (CFG).

        This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
        It determines whether to split the input based on its batch size relative to the expected batch size.

        Args:
            cond (torch.Tensor): The conditional input tensor to process.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - The negative conditional input (uncond_input)
                - The positive conditional input (cond_input)
        """
        if cond.shape[0] == self.batch_size * 2:
            neg_cond = cond[0 : self.batch_size]
            cond = cond[self.batch_size :]
            return neg_cond, cond
        elif cond.shape[0] == self.batch_size:
            return cond, cond
        else:
            raise ValueError(f"Unsupported input shape: {cond.shape}")

    def _is_prepared_input(self, cond):
        """
        Check if the input is already prepared for Classifier-Free Guidance (CFG).

        Args:
            cond (torch.Tensor): The conditional input tensor to check.

        Returns:
            bool: True if the input is already prepared, False otherwise.
        """
        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond

        return cond_tensor.shape[0] == self.batch_size * 2

    def prepare_input(
        self,
        cond_input: Union[torch.Tensor, List[torch.Tensor]],
        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Prepare the input for CFG.

        Args:
            cond_input (Union[torch.Tensor, List[torch.Tensor]]):
                The conditional input. It can be a single tensor or a list of tensors. It must have the same length
                as `negative_cond_input`.
            negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]):
                The negative conditional input. It can be a single tensor or a list of tensors. It must have the same
                length as `cond_input`.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
        """

        # we check if cond_input already has CFG applied, and split if it is the case.
        if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
            return cond_input

        if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
            if isinstance(cond_input, list):
                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
            else:
                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)

        if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
            raise ValueError(
                "`negative_cond_input` is required when cond_input does not already contain negative conditional input"
            )

        if isinstance(cond_input, (list, tuple)):
            if not self.do_classifier_free_guidance:
                return cond_input

            if len(negative_cond_input) != len(cond_input):
                raise ValueError("The length of negative_cond_input and cond_input must be the same.")
            prepared_input = []
            for neg_cond, cond in zip(negative_cond_input, cond_input):
                if neg_cond.shape[0] != cond.shape[0]:
                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
                prepared_input.append(torch.cat([neg_cond, cond], dim=0))
            return prepared_input

        elif isinstance(cond_input, torch.Tensor):
            if not self.do_classifier_free_guidance:
                return cond_input
            else:
                return torch.cat([negative_cond_input, cond_input], dim=0)

        else:
            raise ValueError(f"Unsupported input type: {type(cond_input)}")

    def apply_guidance(
        self,
        model_output: torch.Tensor,
        timestep: int = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if not self.do_classifier_free_guidance:
            return model_output

        if latents is None:
            raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")

        sigma = self.scheduler.sigmas[self.scheduler.step_index]
        noise_pred = latents - sigma * model_output
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = self.normalized_guidance(
            noise_pred_text,
            noise_pred_uncond,
            self.guidance_scale,
            self.momentum_buffer,
            self.adaptive_projected_guidance_rescale_factor,
        )
        noise_pred = (latents - noise_pred) / sigma

        if self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
        return noise_pred


Guiders = Union[CFGGuider, PAGGuider, APGGuider]
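The core of APG is the projection step: the guidance difference is split into components parallel and orthogonal to the conditional prediction, and only the parallel part is weighted by `eta`. A standalone sketch of that decomposition (shapes are illustrative):

```python
import torch

pred_cond, pred_uncond = torch.randn(2, 1, 4, 64, 64)
diff = (pred_cond - pred_uncond).double()

v1 = torch.nn.functional.normalize(pred_cond.double(), dim=[-1, -2, -3])
parallel = (diff * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
orthogonal = diff - parallel

# the two components are orthogonal by construction
inner = (parallel * orthogonal).sum(dim=[-1, -2, -3])
assert torch.allclose(inner, torch.zeros_like(inner), atol=1e-6)
```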
src/diffusers/guiders/__init__.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from ..utils import is_torch_available


if is_torch_available():
    from .adaptive_projected_guidance import AdaptiveProjectedGuidance
    from .auto_guidance import AutoGuidance
    from .classifier_free_guidance import ClassifierFreeGuidance
    from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
    from .skip_layer_guidance import SkipLayerGuidance
    from .smoothed_energy_guidance import SmoothedEnergyGuidance
    from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance

    GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance]
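A quick sketch of how the `GuiderType` union might be used for annotations downstream (the `describe` helper is hypothetical, not part of the PR):

```python
from diffusers.guiders import ClassifierFreeGuidance, GuiderType

def describe(guider: GuiderType) -> None:
    # any of the seven guider classes satisfies this annotation
    print(type(guider).__name__, guider.guidance_scale)

describe(ClassifierFreeGuidance(guidance_scale=7.5))
```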
src/diffusers/guiders/adaptive_projected_guidance.py (new file, 180 lines)
@@ -0,0 +1,180 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, List, TYPE_CHECKING

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg

if TYPE_CHECKING:
    from ..pipelines.modular_pipeline import BlockState


class AdaptiveProjectedGuidance(BaseGuidance):
    """
    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation
            and deterioration of image quality.
        adaptive_projected_guidance_momentum (`float`, defaults to `None`):
            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum: Optional[float] = None,
        adaptive_projected_guidance_rescale: float = 15.0,
        eta: float = 1.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        self.momentum_buffer = None

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        if self._step == 0:
            if self.adaptive_projected_guidance_momentum is not None:
                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_apg_enabled():
            pred = pred_cond
        else:
            pred = normalized_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.momentum_buffer,
                self.eta,
                self.adaptive_projected_guidance_rescale,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_apg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_apg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    diff = pred_cond - pred_uncond
    dim = [-i for i in range(1, len(diff.shape))]

    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor

    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=dim)
    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
    normalized_update = diff_orthogonal + eta * diff_parallel

    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + guidance_scale * normalized_update

    return pred
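Since the new `normalized_guidance` is a free function rather than a method, it can be exercised on its own. A minimal sketch with dummy predictions, assuming the function is in scope (it is module-level in this file); the momentum buffer is optional and omitted:

```python
import torch

pred_cond = torch.randn(1, 4, 64, 64)
pred_uncond = torch.randn(1, 4, 64, 64)

pred = normalized_guidance(
    pred_cond,
    pred_uncond,
    guidance_scale=7.5,
    momentum_buffer=None,
    eta=1.0,
    norm_threshold=15.0,  # caps the norm of the guidance difference
)
print(pred.shape)  # torch.Size([1, 4, 64, 64])
```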
src/diffusers/guiders/auto_guidance.py (new file, 173 lines)
@@ -0,0 +1,173 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Union, TYPE_CHECKING

import torch

from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg

if TYPE_CHECKING:
    from ..pipelines.modular_pipeline import BlockState


class AutoGuidance(BaseGuidance):
    """
    AutoGuidance: https://huggingface.co/papers/2406.02507

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation
            and deterioration of image quality.
        auto_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
            provided, `auto_guidance_config` must be provided.
        auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `auto_guidance_layers` must be provided.
        dropout (`float`, *optional*):
            The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
            `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        auto_guidance_layers: Optional[Union[int, List[int]]] = None,
        auto_guidance_config: Optional[Union[LayerSkipConfig, List[LayerSkipConfig]]] = None,
        dropout: Optional[float] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.auto_guidance_layers = auto_guidance_layers
        self.auto_guidance_config = auto_guidance_config
        self.dropout = dropout
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if auto_guidance_layers is None and auto_guidance_config is None:
            raise ValueError(
                "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable AutoGuidance."
            )
        if auto_guidance_layers is not None and auto_guidance_config is not None:
            raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
        if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None):
            raise ValueError("`dropout` must be provided if, and only if, `auto_guidance_layers` is provided.")

        if auto_guidance_layers is not None:
            if isinstance(auto_guidance_layers, int):
                auto_guidance_layers = [auto_guidance_layers]
            if not isinstance(auto_guidance_layers, list):
                raise ValueError(
                    f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
                )
            auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers]

        if isinstance(auto_guidance_config, LayerSkipConfig):
            auto_guidance_config = [auto_guidance_config]

        if not isinstance(auto_guidance_config, list):
            raise ValueError(
                f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
            )

        self.auto_guidance_config = auto_guidance_config
        self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        self._count_prepared += 1
        if self._is_ag_enabled() and self.is_unconditional:
            for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_ag_enabled() and self.is_unconditional:
            for name in self._auto_guidance_hook_names:
                registry = HookRegistry.check_if_exists_or_initialize(denoiser)
                registry.remove_hook(name, recurse=True)

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_ag_enabled():
            pred = pred_cond
        else:
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_ag_enabled():
            num_conditions += 1
        return num_conditions

    def _is_ag_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
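Construction goes through one of two paths; the simpler one passes layer indices plus a dropout, which the `__init__` above converts to `LayerSkipConfig(layer, fqn="auto", dropout=dropout)` entries. A sketch, with illustrative layer indices and scale:

```python
from diffusers.guiders import AutoGuidance

# layer indices + dropout; one LayerSkipConfig is created per index
guider = AutoGuidance(guidance_scale=2.0, auto_guidance_layers=[7, 8, 9], dropout=0.1)
print(len(guider.auto_guidance_config))  # 3
```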
src/diffusers/guiders/classifier_free_guidance.py (new file, 128 lines)
@@ -0,0 +1,128 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class ClassifierFreeGuidance(BaseGuidance):
|
||||
"""
|
||||
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
|
||||
|
||||
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
|
||||
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
|
||||
inference. This allows the model to tradeoff between generation quality and sample diversity.
|
||||
The original paper proposes scaling and shifting the conditional distribution based on the difference between
|
||||
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
|
||||
|
||||
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
|
||||
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
|
||||
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
|
||||
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
|
||||
    the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_cfg_enabled():
            pred = pred_cond
        else:
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
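For intuition, the two CFG formulations in `forward` above differ only in which branch the scaled shift is added to. A minimal standalone sketch with dummy tensors (shapes and the scale value are illustrative only):

```
import torch

guidance_scale = 7.5
pred_cond = torch.randn(2, 4, 64, 64)    # conditional noise prediction
pred_uncond = torch.randn(2, 4, 64, 64)  # unconditional noise prediction

shift = pred_cond - pred_uncond
# diffusers-native formulation: start from the unconditional branch
pred_native = pred_uncond + guidance_scale * shift
# original paper formulation (use_original_formulation=True): start from the conditional branch
pred_original = pred_cond + guidance_scale * shift

# the two results differ by exactly one `shift`
assert torch.allclose(pred_original, pred_native + shift)
```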
src/diffusers/guiders/classifier_free_zero_star_guidance.py (new file, 144 lines)
@@ -0,0 +1,144 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, List, Optional

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..pipelines.modular_pipeline import BlockState


class ClassifierFreeZeroStarGuidance(BaseGuidance):
    """
    Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886

    This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
    guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
    process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving
    the quality of generated images.

    The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        zero_init_steps (`int`, defaults to `1`):
            The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        zero_init_steps: int = 1,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.zero_init_steps = zero_init_steps
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if self._step < self.zero_init_steps:
            pred = torch.zeros_like(pred_cond)
        elif not self._is_cfg_enabled():
            pred = pred_cond
        else:
            pred_cond_flat = pred_cond.flatten(1)
            pred_uncond_flat = pred_uncond.flatten(1)
            alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
            alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
            pred_uncond = pred_uncond * alpha
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    cond_dtype = cond.dtype
    cond = cond.float()
    uncond = uncond.float()
    dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
    squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    scale = dot_product / squared_norm
    return scale.to(dtype=cond_dtype)
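The coefficient returned by `cfg_zero_star_scale` is the closed-form least-squares scale `s*` that best aligns the unconditional prediction with the conditional one; after rescaling, the residual is orthogonal to the unconditional direction. A quick standalone check with toy tensors (the computation is inlined so the snippet runs on its own):

```
import torch

cond = torch.randn(2, 16)
uncond = torch.randn(2, 16)

# s* = <cond, uncond> / ||uncond||^2, computed per sample
scale = (cond * uncond).sum(dim=1, keepdim=True) / (uncond.pow(2).sum(dim=1, keepdim=True) + 1e-8)

# the residual of the projection is (numerically) orthogonal to uncond,
# confirming that s* minimizes ||cond - s * uncond||^2
residual = cond - scale * uncond
assert torch.allclose((residual * uncond).sum(dim=1), torch.zeros(2), atol=1e-4)
```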
src/diffusers/guiders/guider_utils.py (new file, 215 lines)
@@ -0,0 +1,215 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union

import torch

from ..utils import get_logger


if TYPE_CHECKING:
    from ..pipelines.modular_pipeline import BlockState


logger = get_logger(__name__)  # pylint: disable=invalid-name


class BaseGuidance:
    r"""Base class providing the skeleton for implementing guidance techniques."""

    _input_predictions = None
    _identifier_key = "__guidance_identifier__"

    def __init__(self, start: float = 0.0, stop: float = 1.0):
        self._start = start
        self._stop = stop
        self._step: int = None
        self._num_inference_steps: int = None
        self._timestep: torch.LongTensor = None
        self._count_prepared = 0
        self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
        self._enabled = True

        if not (0.0 <= start < 1.0):
            raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
        if not (start <= stop <= 1.0):
            raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")

        if self._input_predictions is None or not isinstance(self._input_predictions, list):
            raise ValueError(
                "`_input_predictions` must be a list of required prediction names for the guidance technique."
            )

    def disable(self):
        self._enabled = False

    def enable(self):
        self._enabled = True

    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
        self._step = step
        self._num_inference_steps = num_inference_steps
        self._timestep = timestep
        self._count_prepared = 0

    def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
        """
        Set the input fields for the guidance technique. The input fields are used to specify the names of the
        returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is
        obtained from the values of the provided keyword arguments to this method.

        Args:
            **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
                A dictionary where the keys are the names of the fields that will be used to store the data once
                it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
                which is used to look up the required data provided for preparation.

                If a string is provided, it will be used as the conditional data (or unconditional if used with
                a guidance method that requires it). If a tuple of length 2 is provided, the first element must
                be the conditional data identifier and the second element must be the unconditional data
                identifier or None.

        Example:

        ```
        data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}

        BaseGuidance.set_input_fields(
            latents="latents",
            prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
        )
        ```
        """
        for key, value in kwargs.items():
            is_string = isinstance(value, str)
            is_tuple_of_str_with_len_2 = (
                isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
            )
            if not (is_string or is_tuple_of_str_with_len_2):
                raise ValueError(
                    f"Expected `set_input_fields` to be called with a string or a tuple of strings of length 2, but got {type(value)} for key {key}."
                )
        self._input_fields = kwargs

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """
        Prepares the models for the guidance technique on a given batch of data. This method should be overridden
        in subclasses to implement specific model preparation logic.
        """
        self._count_prepared += 1

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """
        Cleans up the models for the guidance technique after a given batch of data. This method should be
        overridden in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or
        other stateful modifications made during `prepare_models`.
        """
        pass

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

    def __call__(self, data: List["BlockState"]) -> Any:
        if not all(hasattr(d, "noise_pred") for d in data):
            raise ValueError("Expected all data to have `noise_pred` attribute.")
        if len(data) != self.num_conditions:
            raise ValueError(
                f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
            )
        forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
        return self.forward(**forward_inputs)

    def forward(self, *args, **kwargs) -> Any:
        raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")

    @property
    def is_conditional(self) -> bool:
        raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")

    @property
    def is_unconditional(self) -> bool:
        return not self.is_conditional

    @property
    def num_conditions(self) -> int:
        raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")

    @classmethod
    def _prepare_batch(
        cls,
        input_fields: Dict[str, Union[str, Tuple[str, str]]],
        data: "BlockState",
        tuple_index: int,
        identifier: str,
    ) -> "BlockState":
        """
        Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of
        the `BaseGuidance` class. It prepares the batch based on the provided tuple index.

        Args:
            input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
                A dictionary where the keys are the names of the fields that will be used to store the data once
                it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
                which is used to look up the required data provided for preparation. If a string is provided, it
                will be used as the conditional data (or unconditional if used with a guidance method that
                requires it). If a tuple of length 2 is provided, the first element must be the conditional data
                identifier and the second element must be the unconditional data identifier or None.
            data (`BlockState`):
                The input data to be prepared.
            tuple_index (`int`):
                The index to use when accessing input fields that are tuples.

        Returns:
            `BlockState`: The prepared batch of data.
        """
        from ..pipelines.modular_pipeline import BlockState

        if input_fields is None:
            raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.")
        data_batch = {}
        for key, value in input_fields.items():
            try:
                if isinstance(value, str):
                    data_batch[key] = getattr(data, value)
                elif isinstance(value, tuple):
                    data_batch[key] = getattr(data, value[tuple_index])
                else:
                    # We've already checked that value is a string or a tuple of strings with length 2
                    pass
            except AttributeError:
                raise ValueError(
                    f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data."
                )
        data_batch[cls._identifier_key] = identifier
        return BlockState(**data_batch)


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
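To see what `rescale_noise_cfg` does: at `guidance_rescale=1.0` the per-sample standard deviation of the output matches the conditional prediction exactly, while intermediate values interpolate between the raw and rescaled CFG outputs. A standalone sketch of the same computation on dummy tensors (shapes and values are illustrative):

```
import torch

noise_cfg = torch.randn(2, 4, 8, 8) * 3.0   # CFG output with inflated std (overexposure)
noise_pred_text = torch.randn(2, 4, 8, 8)   # conditional prediction

std_text = noise_pred_text.std(dim=(1, 2, 3), keepdim=True)
std_cfg = noise_cfg.std(dim=(1, 2, 3), keepdim=True)
rescaled = noise_cfg * (std_text / std_cfg)

guidance_rescale = 1.0  # full rescale for the check; the paper recommends ~0.7 in practice
out = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
assert torch.allclose(out.std(dim=(1, 2, 3)), noise_pred_text.std(dim=(1, 2, 3)), atol=1e-5)
```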
src/diffusers/guiders/skip_layer_guidance.py (new file, 247 lines)
@@ -0,0 +1,247 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, List, Optional, Union

import torch

from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..pipelines.modular_pipeline import BlockState


class SkipLayerGuidance(BaseGuidance):
    """
    Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5

    Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664

    SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works
    by skipping the forward pass of specified transformer blocks during the denoising process on an additional
    conditional batch of data, apart from the conditional and unconditional batches already used in CFG
    ([`~guiders.classifier_free_guidance.ClassifierFreeGuidance`]), and then scaling and shifting the CFG
    predictions based on the difference between the conditional-without-skipping and conditional-with-skipping
    predictions.

    The intuition behind SLG can be thought of as moving the CFG predicted distribution estimates further away
    from worse versions of the conditional distribution estimates (because skipping layers is equivalent to using
    a worse version of the model for the conditional prediction).

    STG is an improvement and follow-up work combining ideas from SLG, PAG, and similar techniques for improving
    generation quality in video diffusion models.

    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

    The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
    defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        skip_layer_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
            values, but it may also lead to overexposure and saturation.
        skip_layer_guidance_start (`float`, defaults to `0.01`):
            The fraction of the total number of denoising steps after which skip layer guidance starts.
        skip_layer_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which skip layer guidance stops.
        skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If
            not provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
            Diffusion 3.5 Medium.
        skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_start: float = 0.01,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
        skip_layer_config: Optional[Union[LayerSkipConfig, List[LayerSkipConfig]]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.skip_layer_guidance_scale = skip_layer_guidance_scale
        self.skip_layer_guidance_start = skip_layer_guidance_start
        self.skip_layer_guidance_stop = skip_layer_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= skip_layer_guidance_start < 1.0):
            raise ValueError(
                f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
            )
        if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
            raise ValueError(
                f"Expected `skip_layer_guidance_stop` to be between {skip_layer_guidance_start} and 1.0, but got {skip_layer_guidance_stop}."
            )

        if skip_layer_guidance_layers is None and skip_layer_config is None:
            raise ValueError(
                "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
            )
        if skip_layer_guidance_layers is not None and skip_layer_config is not None:
            raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")

        if skip_layer_guidance_layers is not None:
            if isinstance(skip_layer_guidance_layers, int):
                skip_layer_guidance_layers = [skip_layer_guidance_layers]
            if not isinstance(skip_layer_guidance_layers, list):
                raise ValueError(
                    f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
                )
            skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]

        if isinstance(skip_layer_config, LayerSkipConfig):
            skip_layer_config = [skip_layer_config]

        if not isinstance(skip_layer_config, list):
            raise ValueError(
                f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
            )

        self.skip_layer_config = skip_layer_config
        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        self._count_prepared += 1
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._skip_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_skip: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        pred = None

        if not self._is_cfg_enabled() and not self._is_slg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            shift = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_cond_skip
            pred = pred + self.skip_layer_guidance_scale * shift
        elif not self._is_slg_enabled():
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            shift_skip = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_slg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_slg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero
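When both the CFG and SLG windows are active, the final branch of `forward` composes the two shifts additively. A minimal standalone sketch of that combination with dummy tensors (the scales are the documented SD 3.5 Medium defaults; shapes are illustrative):

```
import torch

guidance_scale, skip_layer_guidance_scale = 7.5, 2.8
pred_cond = torch.randn(1, 16, 64, 64)       # conditional prediction, full model
pred_uncond = torch.randn(1, 16, 64, 64)     # unconditional prediction
pred_cond_skip = torch.randn(1, 16, 64, 64)  # conditional prediction with layers skipped

shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
# move away from both the unconditional estimate and the "worse model" (layer-skipped) estimate
pred = pred_uncond + guidance_scale * shift + skip_layer_guidance_scale * shift_skip
print(pred.shape)  # torch.Size([1, 16, 64, 64])
```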
src/diffusers/guiders/smoothed_energy_guidance.py (new file, 240 lines)
@@ -0,0 +1,240 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, List, Optional, Union

import torch

from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..pipelines.modular_pipeline import BlockState


class SmoothedEnergyGuidance(BaseGuidance):
    """
    Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760

    SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in
    the future without warning or guarantee of reproducibility. This implementation assumes:
    - Generated images are square (height == width)
    - The model does not combine different modalities together (e.g., text and image latent streams are not
      combined together, as in Flux)

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        seg_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with
            higher values, but it may also lead to overexposure and saturation.
        seg_blur_sigma (`float`, defaults to `9999999.0`):
            The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
            infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
        seg_blur_threshold_inf (`float`, defaults to `9999.0`):
            The threshold above which the blur is considered infinite.
        seg_guidance_start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance starts.
        seg_guidance_stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance stops.
        seg_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers.
            If not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for
            Stable Diffusion 3.5 Medium.
        seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
            The configuration for the smoothed energy layer guidance. Can be a single
            `SmoothedEnergyGuidanceConfig` or a list of `SmoothedEnergyGuidanceConfig`. If not provided,
            `seg_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        seg_guidance_scale: float = 2.8,
        seg_blur_sigma: float = 9999999.0,
        seg_blur_threshold_inf: float = 9999.0,
        seg_guidance_start: float = 0.0,
        seg_guidance_stop: float = 1.0,
        seg_guidance_layers: Optional[Union[int, List[int]]] = None,
        seg_guidance_config: Optional[Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.seg_guidance_scale = seg_guidance_scale
        self.seg_blur_sigma = seg_blur_sigma
        self.seg_blur_threshold_inf = seg_blur_threshold_inf
        self.seg_guidance_start = seg_guidance_start
        self.seg_guidance_stop = seg_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= seg_guidance_start < 1.0):
            raise ValueError(
                f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}."
            )
        if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
            raise ValueError(
                f"Expected `seg_guidance_stop` to be between {seg_guidance_start} and 1.0, but got {seg_guidance_stop}."
            )

        if seg_guidance_layers is None and seg_guidance_config is None:
            raise ValueError(
                "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
            )
        if seg_guidance_layers is not None and seg_guidance_config is not None:
            raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")

        if seg_guidance_layers is not None:
            if isinstance(seg_guidance_layers, int):
                seg_guidance_layers = [seg_guidance_layers]
            if not isinstance(seg_guidance_layers, list):
                raise ValueError(
                    f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
                )
            seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]

        if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
            seg_guidance_config = [seg_guidance_config]

        if not isinstance(seg_guidance_config, list):
            raise ValueError(
                f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
            )

        self.seg_guidance_config = seg_guidance_config
        self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        self._count_prepared += 1
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
                _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._seg_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_seg: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        pred = None

        if not self._is_cfg_enabled() and not self._is_seg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            shift = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_cond_seg
            pred = pred + self.seg_guidance_scale * shift
        elif not self._is_seg_enabled():
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            shift_seg = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_seg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_seg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.seg_guidance_scale, 0.0)

        return is_within_range and not is_zero
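The `_is_*_enabled` gates above map the start/stop fractions onto a window of step indices: half-open for CFG, open on both ends for SEG/SLG. A quick standalone check of which steps are active (values are illustrative):

```
num_inference_steps = 50
seg_guidance_start, seg_guidance_stop = 0.0, 1.0

start_step = int(seg_guidance_start * num_inference_steps)
stop_step = int(seg_guidance_stop * num_inference_steps)

# SEG/SLG use strict inequalities, so the boundary steps themselves are excluded;
# the CFG gate uses `start_step <= step < stop_step` instead
active = [step for step in range(num_inference_steps) if start_step < step < stop_step]
print(active[0], active[-1])  # 1 49 -> step 0 is excluded even with start=0.0
```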
src/diffusers/guiders/tangential_classifier_free_guidance.py (new file, 133 lines)
@@ -0,0 +1,133 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, List, Optional

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..pipelines.modular_pipeline import BlockState


class TangentialClassifierFreeGuidance(BaseGuidance):
    """
    Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_tcfg_enabled():
            pred = pred_cond
        else:
            pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_tcfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_tcfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


def normalized_guidance(
    pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
) -> torch.Tensor:
    cond_dtype = pred_cond.dtype
    preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
    preds = preds.flatten(2)
    U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
    Vh_modified = Vh.clone()
    Vh_modified[:, 1] = 0

    uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
    x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
    x_Vh_V = torch.matmul(x_Vh, Vh_modified)
    pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)

    pred = pred_cond if use_original_formulation else pred_uncond
    shift = pred_cond - pred_uncond
    pred = pred + guidance_scale * shift

    return pred
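The SVD step in `normalized_guidance` computes the right-singular vectors of the stacked {conditional, unconditional} pair, zeroes the second one, and projects the unconditional prediction onto the remaining dominant direction before applying the usual CFG combination. A toy-scale standalone sketch of the same projection:

```
import torch

pred_cond = torch.randn(1, 8)
pred_uncond = torch.randn(1, 8)

preds = torch.stack([pred_cond, pred_uncond], dim=1)     # (batch, 2, dim)
_, _, Vh = torch.linalg.svd(preds, full_matrices=False)  # Vh: (batch, 2, dim)
Vh_modified = Vh.clone()
Vh_modified[:, 1] = 0  # drop the second right-singular direction

uncond_flat = pred_uncond.reshape(1, 1, -1)
projected = uncond_flat @ Vh.transpose(-2, -1) @ Vh_modified  # keep only the dominant direction

print(projected.shape)  # torch.Size([1, 1, 8]) -- lies along Vh[:, 0]
```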
@@ -5,5 +5,7 @@ if is_torch_available():
    from .faster_cache import FasterCacheConfig, apply_faster_cache
    from .group_offloading import apply_group_offloading
    from .hooks import HookRegistry, ModelHook
    from .layer_skip import LayerSkipConfig, apply_layer_skip
    from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
    from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
    from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
src/diffusers/hooks/_common.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)


def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
    for submodule_name, submodule in module.named_modules():
        if submodule_name == fqn:
            return submodule
    return None
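`_get_submodule_from_fqn` resolves a fully qualified module name by a linear scan over `named_modules()`. A standalone sketch of the same lookup on a toy module (the helper body is inlined so the snippet runs without diffusers):

```
import torch

def get_submodule_from_fqn(module, fqn):
    # same linear scan as the helper above
    for submodule_name, submodule in module.named_modules():
        if submodule_name == fqn:
            return submodule
    return None

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Sequential(torch.nn.ReLU()))
print(get_submodule_from_fqn(model, "1.0"))  # ReLU() -- the nested module's fully qualified name
```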
src/diffusers/hooks/_helpers.py (new file, 271 lines)
@@ -0,0 +1,271 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Optional, Type

from ..models.attention import BasicTransformerBlock
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
    HunyuanVideoSingleTransformerBlock,
    HunyuanVideoTokenReplaceSingleTransformerBlock,
    HunyuanVideoTokenReplaceTransformerBlock,
    HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock


@dataclass
class AttentionProcessorMetadata:
    skip_processor_output_fn: Callable[[Any], Any]


@dataclass
class TransformerBlockMetadata:
    skip_block_output_fn: Callable[[Any], Any]
    return_hidden_states_index: Optional[int] = None
    return_encoder_hidden_states_index: Optional[int] = None


class AttentionProcessorRegistry:
    _registry = {}

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]


class TransformerBlockRegistry:
    _registry = {}

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]
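The registries are plain class-level dictionaries keyed by block or processor type, populated by the `_register_*` functions below. A minimal usage sketch, assuming the classes above are in scope (`MyBlock` is a hypothetical stand-in, not a diffusers class):

```
# hypothetical block class, for illustration only
class MyBlock:
    pass

TransformerBlockRegistry.register(
    model_class=MyBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=lambda self, *args, **kwargs: kwargs.get("hidden_states"),
        return_hidden_states_index=0,
    ),
)
assert TransformerBlockRegistry.get(MyBlock).return_hidden_states_index == 0
```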
def _register_attention_processors_metadata():
|
||||
# AttnProcessor2_0
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=AttnProcessor2_0,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
|
||||
# CogView4AttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=CogView4AttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
# BasicTransformerBlock
|
||||
TransformerBlockRegistry.register(
    model_class=BasicTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=None,
    ),
)

# CogVideoX
TransformerBlockRegistry.register(
    model_class=CogVideoXBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=1,
    ),
)

# CogView4
TransformerBlockRegistry.register(
    model_class=CogView4TransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=1,
    ),
)

# Flux
TransformerBlockRegistry.register(
    model_class=FluxTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
        return_hidden_states_index=1,
        return_encoder_hidden_states_index=0,
    ),
)
TransformerBlockRegistry.register(
    model_class=FluxSingleTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
        return_hidden_states_index=1,
        return_encoder_hidden_states_index=0,
    ),
)

# HunyuanVideo
TransformerBlockRegistry.register(
    model_class=HunyuanVideoTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=1,
    ),
)
TransformerBlockRegistry.register(
    model_class=HunyuanVideoSingleTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=1,
    ),
)
TransformerBlockRegistry.register(
    model_class=HunyuanVideoTokenReplaceTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=1,
    ),
)
TransformerBlockRegistry.register(
    model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=1,
    ),
)

# LTXVideo
TransformerBlockRegistry.register(
    model_class=LTXVideoTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=None,
    ),
)

# Mochi
TransformerBlockRegistry.register(
    model_class=MochiTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=1,
    ),
)

# Wan
TransformerBlockRegistry.register(
    model_class=WanTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=None,
    ),
)

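# Illustrative lookup, a minimal sketch rather than part of the registry module:
# once a block class is registered, hooks fetch its metadata by class and use the
# recorded indices to pull hidden/encoder hidden states out of the block output.
#
#     metadata = TransformerBlockRegistry.get(CogVideoXBlock)
#     assert metadata.return_hidden_states_index == 0
#     assert metadata.return_encoder_hidden_states_index == 1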
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states


def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return hidden_states, encoder_hidden_states


_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states


def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states


def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return hidden_states, encoder_hidden_states


def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return encoder_hidden_states, hidden_states


_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
# fmt: on


_register_attention_processors_metadata()
_register_transformer_blocks_metadata()
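# The long function names above encode argument positions and return order. Each
# skip function echoes a block's inputs back in that block's output order, so a
# skipped block behaves as an identity map. Minimal sketch with assumed tensors:
#
#     h, e = torch.randn(1, 16, 64), torch.randn(1, 8, 64)
#     out = _skip_block_output_fn_FluxTransformerBlock(None, hidden_states=h, encoder_hidden_states=e)
#     assert out[0] is e and out[1] is h  # Flux returns (encoder_hidden_states, hidden_states)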
229  src/diffusers/hooks/layer_skip.py  Normal file
@@ -0,0 +1,229 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Callable, List, Optional

import torch

from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_LAYER_SKIP_HOOK = "layer_skip_hook"


@dataclass
class LayerSkipConfig:
    r"""
    Configuration for skipping internal transformer blocks when executing a transformer model.

    Args:
        indices (`List[int]`):
            The indices of the layers to skip.
        fqn (`str`, defaults to `"auto"`):
            The fully qualified name identifying the stack of transformer blocks. Typically, this is
            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
            For automatic detection, set this to `"auto"`. "auto" only works on DiT models; for UNet models, you must
            provide the correct `fqn`.
        skip_attention (`bool`, defaults to `True`):
            Whether to skip attention blocks.
        skip_ff (`bool`, defaults to `True`):
            Whether to skip feed-forward blocks.
        skip_attention_scores (`bool`, defaults to `False`):
            Whether to skip attention score computation in the attention blocks. This is equivalent to using the
            `value` projections as the output of scaled dot product attention.
        dropout (`float`, defaults to `1.0`):
            The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
            meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
            skipped layers are fully retained, which is equivalent to not skipping any layers.
    """

    indices: List[int]
    fqn: str = "auto"
    skip_attention: bool = True
    skip_attention_scores: bool = False
    skip_ff: bool = True
    dropout: float = 1.0

    def __post_init__(self):
        if not (0 <= self.dropout <= 1):
            raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
        if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
            raise ValueError(
                "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
            )

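# Minimal sketch of the validation above: partial dropout is incompatible with
# score skipping and is rejected at construction time.
#
#     LayerSkipConfig(indices=[0], dropout=0.5, skip_attention_scores=True)
#     # ValueError: Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. ...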
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.nn.functional.scaled_dot_product_attention:
            value = kwargs.get("value", None)
            if value is None:
                value = args[2]
            return value
        return func(*args, **kwargs)

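# Under this function mode, scaled dot product attention degenerates to its
# `value` input, which is what "skipping attention scores" means here. Minimal
# runnable sketch:
#
#     import torch
#     import torch.nn.functional as F
#
#     q, k, v = (torch.randn(1, 4, 16, 8) for _ in range(3))
#     with AttentionScoreSkipFunctionMode():
#         out = F.scaled_dot_product_attention(q, k, v)
#     assert torch.equal(out, v)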
class AttentionProcessorSkipHook(ModelHook):
    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
        self.skip_processor_output_fn = skip_processor_output_fn
        self.skip_attention_scores = skip_attention_scores
        self.dropout = dropout

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if self.skip_attention_scores:
            if not math.isclose(self.dropout, 1.0):
                raise ValueError(
                    "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
                )
            with AttentionScoreSkipFunctionMode():
                output = self.fn_ref.original_forward(*args, **kwargs)
        else:
            if math.isclose(self.dropout, 1.0):
                output = self.skip_processor_output_fn(module, *args, **kwargs)
            else:
                output = self.fn_ref.original_forward(*args, **kwargs)
                output = torch.nn.functional.dropout(output, p=self.dropout)
        return output

class FeedForwardSkipHook(ModelHook):
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            output = kwargs.get("hidden_states", None)
            if output is None:
                output = kwargs.get("x", None)
            if output is None and len(args) > 0:
                output = args[0]
        else:
            output = self.fn_ref.original_forward(*args, **kwargs)
            output = torch.nn.functional.dropout(output, p=self.dropout)
        return output

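# Minimal sketch of the skip path. Hook wiring is normally done by HookRegistry;
# calling new_forward directly here is only for illustration:
#
#     import torch
#     import torch.nn as nn
#
#     ff = nn.Linear(8, 8)  # stand-in for a feed-forward block
#     hook = FeedForwardSkipHook(dropout=1.0)
#     x = torch.randn(2, 8)
#     assert torch.equal(hook.new_forward(ff, x), x)  # input is returned unchanged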
class TransformerBlockSkipHook(ModelHook):
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def initialize_hook(self, module):
        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            output = self._metadata.skip_block_output_fn(module, *args, **kwargs)
        else:
            output = self.fn_ref.original_forward(*args, **kwargs)
            output = torch.nn.functional.dropout(output, p=self.dropout)
        return output

def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
    r"""
    Apply layer skipping to internal layers of a transformer.

    Args:
        module (`torch.nn.Module`):
            The transformer model to which the layer skip hook should be applied.
        config (`LayerSkipConfig`):
            The configuration for the layer skip hook.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXTransformer3DModel, LayerSkipConfig, apply_layer_skip

    >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
    ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
    ... )
    >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
    >>> apply_layer_skip(transformer, config)
    ```
    """
    _apply_layer_skip_hook(module, config)

def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
    name = name or _LAYER_SKIP_HOOK

    if config.skip_attention and config.skip_attention_scores:
        raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
    if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
        raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.")

    if config.fqn == "auto":
        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
            if hasattr(module, identifier):
                config.fqn = identifier
                break
        else:
            raise ValueError(
                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
            )

    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
    if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
        raise ValueError(
            f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
            f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
        )
    if len(config.indices) == 0:
        raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")

    blocks_found = False
    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue

        blocks_found = True

        if config.skip_attention and config.skip_ff:
            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
            registry = HookRegistry.check_if_exists_or_initialize(block)
            hook = TransformerBlockSkipHook(config.dropout)
            registry.register_hook(hook, name)

        elif config.skip_attention or config.skip_attention_scores:
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
                    logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
                    registry.register_hook(hook, name)

        if config.skip_ff:
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _FEEDFORWARD_CLASSES):
                    logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = FeedForwardSkipHook(config.dropout)
                    registry.register_hook(hook, name)

    if not blocks_found:
        raise ValueError(
            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
        )
158  src/diffusers/hooks/smoothed_energy_guidance_utils.py  Normal file
@@ -0,0 +1,158 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F

from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"


@dataclass
class SmoothedEnergyGuidanceConfig:
    r"""
    Configuration for Smoothed Energy Guidance, which blurs the self-attention queries of selected internal
    transformer blocks when executing a transformer model.

    Args:
        indices (`List[int]`):
            The indices of the transformer blocks whose self-attention queries should be blurred.
        fqn (`str`, defaults to `"auto"`):
            The fully qualified name identifying the stack of transformer blocks. Typically, this is
            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
            For automatic detection, set this to `"auto"`. "auto" only works on DiT models; for UNet models, you must
            provide the correct `fqn`.
        _query_proj_identifiers (`List[str]`, defaults to `None`):
            The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
            `None`, `to_q` is used by default.
    """

    indices: List[int]
    fqn: str = "auto"
    _query_proj_identifiers: Optional[List[str]] = None

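# Minimal sketch with assumed indices: blur the self-attention queries of blocks
# 10-19 in a DiT-style model whose blocks live under `transformer_blocks`.
#
#     config = SmoothedEnergyGuidanceConfig(indices=list(range(10, 20)), fqn="transformer_blocks")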
class SmoothedEnergyGuidanceHook(ModelHook):
    def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
        super().__init__()
        self.blur_sigma = blur_sigma
        self.blur_threshold_inf = blur_threshold_inf

    def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
        # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
        kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
        smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
        return smoothed_output

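# Worked example of the kernel-size rule above, which always yields an odd window
# spanning roughly +/- 3 sigma:
#
#     blur_sigma = 1.0: ceil(6.0) = 6 -> 6 + 1 - (6 % 2) = 7  (7x7 kernel)
#     blur_sigma = 1.5: ceil(9.0) = 9 -> 9 + 1 - (9 % 2) = 9  (9x9 kernel)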
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
    name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK

    if config.fqn == "auto":
        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
            if hasattr(module, identifier):
                config.fqn = identifier
                break
        else:
            raise ValueError(
                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
            )

    if config._query_proj_identifiers is None:
        config._query_proj_identifiers = ["to_q"]

    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
    blocks_found = False
    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue

        blocks_found = True

        for submodule_name, submodule in block.named_modules():
            if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
                continue
            for identifier in config._query_proj_identifiers:
                query_proj = getattr(submodule, identifier, None)
                if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
                    continue
                logger.debug(
                    f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
                )
                registry = HookRegistry.check_if_exists_or_initialize(query_proj)
                hook = SmoothedEnergyGuidanceHook(blur_sigma)
                registry.register_hook(hook, name)

    if not blocks_found:
        raise ValueError(
            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
        )

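# Hypothetical wiring of the private helper above (sketch only; the model and
# values are assumed):
#
#     config = SmoothedEnergyGuidanceConfig(indices=[10, 20], fqn="transformer_blocks")
#     _apply_smoothed_energy_guidance_hook(transformer, config, blur_sigma=1.0)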
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
    """
    This implementation assumes that the input query is for visual (image/video) tokens, so that a 2D gaussian blur
    can be applied. However, some models use joint text-visual token attention, for which this may not be suitable.
    Additionally, this implementation assumes that the visual tokens come from a square image/video. In practice,
    despite these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable
    results for Smoothed Energy Guidance.

    SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
    future without warning or guarantee of reproducibility.
    """
    assert query.ndim == 3

    is_inf = sigma > sigma_threshold_inf
    batch_size, seq_len, embed_dim = query.shape

    seq_len_sqrt = int(math.sqrt(seq_len))
    num_square_tokens = seq_len_sqrt * seq_len_sqrt
    query_slice = query[:, :num_square_tokens, :]
    query_slice = query_slice.permute(0, 2, 1)
    query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)

    if is_inf:
        # An effectively infinite sigma is a uniform blur, i.e. the spatial mean
        query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
    else:
        kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
        kernel_size_half = (kernel_size - 1) / 2

        x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
        pdf = torch.exp(-0.5 * (x / sigma).pow(2))
        kernel1d = pdf / pdf.sum()
        kernel1d = kernel1d.to(query)
        kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
        kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])

        padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
        query_slice = F.pad(query_slice, padding, mode="reflect")
        query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)

    query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
    query_slice = query_slice.permute(0, 2, 1)
    query[:, :num_square_tokens, :] = query_slice.clone()

    return query
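# Minimal sketch of the blur on a square token grid (batch=1, 16x16 visual tokens,
# embed_dim=8); the sigma and threshold values are illustrative:
#
#     q = torch.randn(1, 256, 8)
#     blurred = _gaussian_blur_2d(q, kernel_size=7, sigma=1.0, sigma_threshold_inf=9999.9)
#     assert blurred.shape == q.shape  # modified in place and returned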
@@ -19,7 +19,6 @@ import PIL
 import torch
 from collections import OrderedDict
 
-from ...guider import CFGGuider
 from ...image_processor import VaeImageProcessor, PipelineImageInput
 from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
 from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
@@ -31,7 +30,7 @@ from ...utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import randn_tensor, unwrap_module
 from ..controlnet.multicontrolnet import MultiControlNetModel
 from ..modular_pipeline import (
     AutoPipelineBlocks,
@@ -58,7 +57,7 @@ from transformers import (
 )
 
 from ...schedulers import KarrasDiffusionSchedulers
-from ...guider import Guiders, CFGGuider
+from ...guiders import GuiderType, ClassifierFreeGuidance
 
 import numpy as np
 
@@ -185,6 +184,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
             ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
             ComponentSpec("feature_extractor", CLIPImageProcessor),
             ComponentSpec("unet", UNet2DConditionModel),
+            ComponentSpec("guider", GuiderType),
         ]
 
     @property
@@ -195,11 +195,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
                 PipelineImageInput,
                 required=True,
                 description="The image(s) to be used as ip adapter"
             ),
-            InputParam(
-                "guidance_scale",
-                default=5.0,
-            ),
         )
     ]

@@ -237,10 +233,10 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
 
     # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+        self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds
     ):
         image_embeds = []
-        if do_classifier_free_guidance:
+        if prepare_unconditional_embeds:
             negative_image_embeds = []
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
@@ -260,11 +256,11 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
                     )
 
                 image_embeds.append(single_image_embeds[None, :])
-                if do_classifier_free_guidance:
+                if prepare_unconditional_embeds:
                     negative_image_embeds.append(single_negative_image_embeds[None, :])
         else:
             for single_image_embeds in ip_adapter_image_embeds:
-                if do_classifier_free_guidance:
+                if prepare_unconditional_embeds:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                     negative_image_embeds.append(single_negative_image_embeds)
                 image_embeds.append(single_image_embeds)
@@ -272,7 +268,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
         ip_adapter_image_embeds = []
         for i, single_image_embeds in enumerate(image_embeds):
             single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-            if do_classifier_free_guidance:
+            if prepare_unconditional_embeds:
                 single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                 single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
@@ -285,7 +281,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
     def __call__(self, pipeline, state: PipelineState) -> PipelineState:
         data = self.get_block_state(state)
 
-        data.do_classifier_free_guidance = data.guidance_scale > 1.0
+        data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1
         data.device = pipeline._execution_device
 
         data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
@@ -294,9 +290,9 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
             ip_adapter_image_embeds=None,
             device=data.device,
             num_images_per_prompt=1,
-            do_classifier_free_guidance=data.do_classifier_free_guidance,
+            prepare_unconditional_embeds=data.prepare_unconditional_embeds,
         )
-        if data.do_classifier_free_guidance:
+        if data.prepare_unconditional_embeds:
             data.negative_ip_adapter_embeds = []
             for i, image_embeds in enumerate(data.ip_adapter_embeds):
                 negative_image_embeds, image_embeds = image_embeds.chunk(2)
@@ -324,6 +320,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
             ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
             ComponentSpec("tokenizer", CLIPTokenizer),
             ComponentSpec("tokenizer_2", CLIPTokenizer),
+            ComponentSpec("guider", GuiderType),
         ]
 
     @property
@@ -338,7 +335,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
             InputParam("negative_prompt"),
             InputParam("negative_prompt_2"),
             InputParam("cross_attention_kwargs"),
-            InputParam("guidance_scale", default=5.0),
             InputParam("clip_skip"),
         ]
@@ -359,7 +355,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
         elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}")
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with self -> components
     def encode_prompt(
         self,
         components,
@@ -367,7 +362,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
         prompt_2: Optional[str] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
+        prepare_unconditional_embeds: bool = True,
         negative_prompt: Optional[str] = None,
         negative_prompt_2: Optional[str] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
@@ -390,8 +385,8 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
-            do_classifier_free_guidance (`bool`):
-                whether to use classifier free guidance or not
+            prepare_unconditional_embeds (`bool`):
+                whether to prepare unconditional embeddings or not
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -499,10 +494,10 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
 
         # get unconditional embeddings for classifier free guidance
         zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
-        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+        if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
-        elif do_classifier_free_guidance and negative_prompt_embeds is None:
+        elif prepare_unconditional_embeds and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt
@@ -563,7 +558,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
-        if do_classifier_free_guidance:
+        if prepare_unconditional_embeds:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
@@ -578,7 +573,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
-        if do_classifier_free_guidance:
+        if prepare_unconditional_embeds:
             negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                 bs_embed * num_images_per_prompt, -1
             )
@@ -602,10 +597,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
         data = self.get_block_state(state)
         self.check_inputs(pipeline, data)
 
-        data.do_classifier_free_guidance = data.guidance_scale > 1.0
+        data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1
         data.device = pipeline._execution_device
 
-
         # Encode input prompt
         data.text_encoder_lora_scale = (
             data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None
@@ -621,7 +615,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
             data.prompt_2,
             data.device,
             1,
-            data.do_classifier_free_guidance,
+            data.prepare_unconditional_embeds,
             data.negative_prompt,
             data.negative_prompt_2,
             prompt_embeds=None,
@@ -1751,7 +1745,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
             InputParam("crops_coords_top_left", default=(0, 0)),
             InputParam("negative_crops_coords_top_left", default=(0, 0)),
             InputParam("num_images_per_prompt", default=1),
-            InputParam("guidance_scale", required=True),
             InputParam("aesthetic_score", default=6.0),
             InputParam("negative_aesthetic_score", default=2.0),
         ]
@@ -1898,7 +1891,8 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
             and pipeline.unet is not None
             and pipeline.unet.config.time_cond_proj_dim is not None
         ):
-            data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
+            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
+            data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
             data.timestep_cond = self.get_guidance_scale_embedding(
                 data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim
             ).to(device=data.device, dtype=data.latents.dtype)
@@ -1926,7 +1920,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
             InputParam("crops_coords_top_left", default=(0, 0)),
             InputParam("negative_crops_coords_top_left", default=(0, 0)),
             InputParam("num_images_per_prompt", default=1),
-            InputParam("guidance_scale", default=5.0),
         ]
 
     @property
@@ -2052,7 +2045,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
             and pipeline.unet is not None
             and pipeline.unet.config.time_cond_proj_dim is not None
         ):
-            data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
+            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
+            data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
             data.timestep_cond = self.get_guidance_scale_embedding(
                 data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim
             ).to(device=data.device, dtype=data.latents.dtype)
@@ -2068,7 +2062,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
     @property
     def expected_components(self) -> List[ComponentSpec]:
         return [
-            ComponentSpec("guider", CFGGuider, obj=CFGGuider()),
+            ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
             ComponentSpec("scheduler", KarrasDiffusionSchedulers),
             ComponentSpec("unet", UNet2DConditionModel),
         ]
@@ -2082,12 +2076,9 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
         return [
-            InputParam("guidance_scale", default=5.0),
-            InputParam("guidance_rescale", default=0.0),
             InputParam("cross_attention_kwargs"),
             InputParam("generator"),
             InputParam("eta", default=0.0),
-            InputParam("guider_kwargs"),
             InputParam("num_images_per_prompt", default=1),
         ]

@@ -2238,78 +2229,63 @@
 
         data.num_channels_unet = pipeline.unet.config.in_channels
         data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
 
-        # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale
-        data.guider_kwargs = data.guider_kwargs or {}
-        data.guider_kwargs = {
-            **data.guider_kwargs,
-            "disable_guidance": data.disable_guidance,
-            "guidance_scale": data.guidance_scale,
-            "guidance_rescale": data.guidance_rescale,
-            "batch_size": data.batch_size * data.num_images_per_prompt,
-        }
-
-        pipeline.guider.set_guider(pipeline, data.guider_kwargs)
-        # Prepare conditional inputs using the guider
-        data.prompt_embeds = pipeline.guider.prepare_input(
-            data.prompt_embeds,
-            data.negative_prompt_embeds,
-        )
-        data.add_time_ids = pipeline.guider.prepare_input(
-            data.add_time_ids,
-            data.negative_add_time_ids,
-        )
-        data.pooled_prompt_embeds = pipeline.guider.prepare_input(
-            data.pooled_prompt_embeds,
-            data.negative_pooled_prompt_embeds,
-        )
-
-        if data.num_channels_unet == 9:
-            data.mask = pipeline.guider.prepare_input(data.mask, data.mask)
-            data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents)
-
-        data.added_cond_kwargs = {
-            "text_embeds": data.pooled_prompt_embeds,
-            "time_ids": data.add_time_ids,
-        }
-
-        if data.ip_adapter_embeds is not None:
-            data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
-            data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds
+        if data.disable_guidance:
+            pipeline.guider.disable()
+        else:
+            pipeline.guider.enable()
 
         # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
         data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)
 
+        pipeline.guider.set_input_fields(
+            prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
+            add_time_ids=("add_time_ids", "negative_add_time_ids"),
+            pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
+            ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"),
+        )
 
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                # expand the latents if we are doing classifier free guidance
-                data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
-                data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
+                guider_data = pipeline.guider.prepare_inputs(data)
 
-                # inpainting
+                data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t)
+
+                # Prepare for inpainting
                 if data.num_channels_unet == 9:
-                    data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1)
+                    data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1)
 
+                for batch in guider_data:
+                    pipeline.guider.prepare_models(pipeline.unet)
+
+                    # Prepare additional conditionings
+                    batch.added_cond_kwargs = {
+                        "text_embeds": batch.pooled_prompt_embeds,
+                        "time_ids": batch.add_time_ids,
+                    }
+                    if batch.ip_adapter_embeds is not None:
+                        batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds
+
+                    # Predict the noise residual
+                    batch.noise_pred = pipeline.unet(
+                        data.scaled_latents,
+                        t,
+                        encoder_hidden_states=batch.prompt_embeds,
+                        timestep_cond=data.timestep_cond,
+                        cross_attention_kwargs=data.cross_attention_kwargs,
+                        added_cond_kwargs=batch.added_cond_kwargs,
+                        return_dict=False,
+                    )[0]
+                    pipeline.guider.cleanup_models(pipeline.unet)
 
-                # predict the noise residual
-                data.noise_pred = pipeline.unet(
-                    data.latent_model_input,
-                    t,
-                    encoder_hidden_states=data.prompt_embeds,
-                    timestep_cond=data.timestep_cond,
-                    cross_attention_kwargs=data.cross_attention_kwargs,
-                    added_cond_kwargs=data.added_cond_kwargs,
-                    return_dict=False,
-                )[0]
-                # perform guidance
-                data.noise_pred = pipeline.guider.apply_guidance(
-                    data.noise_pred,
-                    timestep=t,
-                    latents=data.latents,
-                )
-                # compute the previous noisy sample x_t -> x_t-1
+                # Perform guidance
+                data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data)
+
+                # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
-                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0]
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
@@ -2328,7 +2304,6 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
             if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                 progress_bar.update()
 
-        pipeline.guider.reset_guider(pipeline)
         self.add_block_state(state, data)
 
         return pipeline, state
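# The per-step guider protocol used in the denoise loop above, condensed as a
# sketch (only names that appear in this diff; the surrounding wiring is assumed):
#
#     guider.set_state(step=i, num_inference_steps=num_steps, timestep=t)
#     batches = guider.prepare_inputs(data)       # one batch per condition (cond, uncond, ...)
#     for batch in batches:
#         guider.prepare_models(unet)             # e.g. enable layer skipping for SLG/SEG
#         batch.noise_pred = unet(...)            # one forward pass per condition
#         guider.cleanup_models(unet)
#     noise_pred, scheduler_step_kwargs = guider(batches)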
@@ -2341,12 +2316,11 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
     @property
     def expected_components(self) -> List[ComponentSpec]:
         return [
-            ComponentSpec("guider", CFGGuider, obj=CFGGuider()),
+            ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
             ComponentSpec("scheduler", KarrasDiffusionSchedulers),
             ComponentSpec("unet", UNet2DConditionModel),
             ComponentSpec("controlnet", ControlNetModel),
             ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
-            ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()),
         ]
 
     @property
@@ -2362,12 +2336,9 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             InputParam("controlnet_conditioning_scale", default=1.0),
             InputParam("guess_mode", default=False),
             InputParam("num_images_per_prompt", default=1),
-            InputParam("guidance_scale", default=5.0),
-            InputParam("guidance_rescale", default=0.0),
             InputParam("cross_attention_kwargs"),
             InputParam("generator"),
             InputParam("eta", default=0.0),
-            InputParam("guider_kwargs"),
         ]
 
     @property
@@ -2514,8 +2485,8 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
         else:
             image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
 
         image_batch_size = image.shape[0]
 
         if image_batch_size == 1:
             repeat_by = batch_size
         else:
@@ -2523,9 +2494,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             repeat_by = num_images_per_prompt
 
         image = image.repeat_interleave(repeat_by, dim=0)
-
         image = image.to(device=device, dtype=dtype)
-
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
@@ -2556,14 +2525,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         data.num_channels_unet = pipeline.unet.config.in_channels
 
         # (1) prepare controlnet inputs
-
         data.device = pipeline._execution_device
-
         data.height, data.width = data.latents.shape[-2:]
         data.height = data.height * pipeline.vae_scale_factor
         data.width = data.width * pipeline.vae_scale_factor
 
-        controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet
+        controlnet = unwrap_module(pipeline.controlnet)
 
         # (1.1)
         # control_guidance_start/control_guidance_end (align format)
@@ -2641,72 +2608,30 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
 
         # (2) Prepare conditional inputs for unet using the guider
-        # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
         data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
-        data.guider_kwargs = data.guider_kwargs or {}
-        data.guider_kwargs = {
-            **data.guider_kwargs,
-            "disable_guidance": data.disable_guidance,
-            "guidance_scale": data.guidance_scale,
-            "guidance_rescale": data.guidance_rescale,
-            "batch_size": data.batch_size * data.num_images_per_prompt,
-        }
-        pipeline.guider.set_guider(pipeline, data.guider_kwargs)
-        data.prompt_embeds = pipeline.guider.prepare_input(
-            data.prompt_embeds,
-            data.negative_prompt_embeds,
-        )
-        data.add_time_ids = pipeline.guider.prepare_input(
-            data.add_time_ids,
-            data.negative_add_time_ids,
-        )
-        data.pooled_prompt_embeds = pipeline.guider.prepare_input(
-            data.pooled_prompt_embeds,
-            data.negative_pooled_prompt_embeds,
-        )
-        if data.num_channels_unet == 9:
-            data.mask = pipeline.guider.prepare_input(data.mask, data.mask)
-            data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents)
-
-        data.added_cond_kwargs = {
-            "text_embeds": data.pooled_prompt_embeds,
-            "time_ids": data.add_time_ids,
-        }
-
-        if data.ip_adapter_embeds is not None:
-            data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
-            data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds
-
-        # (3) Prepare conditional inputs for controlnet using the guider
-        data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
-        data.controlnet_guider_kwargs = data.guider_kwargs or {}
-        data.controlnet_guider_kwargs = {
-            **data.controlnet_guider_kwargs,
-            "disable_guidance": data.controlnet_disable_guidance,
-            "guidance_scale": data.guidance_scale,
-            "guidance_rescale": data.guidance_rescale,
-            "batch_size": data.batch_size * data.num_images_per_prompt,
-        }
-        pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs)
-        data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
-        data.controlnet_added_cond_kwargs = {
-            "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
-            "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
-        }
-        data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image)
+        if data.disable_guidance:
+            pipeline.guider.disable()
+        else:
+            pipeline.guider.enable()
 
         # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
         data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)
 
+        pipeline.guider.set_input_fields(
+            prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
+            add_time_ids=("add_time_ids", "negative_add_time_ids"),
+            pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
+            ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"),
+        )
 
         # (5) Denoise loop
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                # prepare latents for unet using the guider
-                data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
+                guider_data = pipeline.guider.prepare_inputs(data)
 
-                # prepare latents for controlnet using the guider
-                data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
+                data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t)
 
                 if isinstance(data.controlnet_keep[i], list):
                     data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])]
@@ -2715,52 +2640,72 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
                 if isinstance(data.controlnet_cond_scale, list):
                     data.controlnet_cond_scale = data.controlnet_cond_scale[0]
                 data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i]
 
+                for batch in guider_data:
+                    pipeline.guider.prepare_models(pipeline.unet)
+
+                    # Prepare additional conditionings
+                    batch.added_cond_kwargs = {
+                        "text_embeds": batch.pooled_prompt_embeds,
+                        "time_ids": batch.add_time_ids,
+                    }
+                    if batch.ip_adapter_embeds is not None:
+                        batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds
+
+                    # Prepare controlnet additional conditionings
+                    batch.controlnet_added_cond_kwargs = {
+                        "text_embeds": batch.pooled_prompt_embeds,
+                        "time_ids": batch.add_time_ids,
+                    }
 
-                data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
-                    pipeline.scheduler.scale_model_input(data.control_model_input, t),
-                    t,
-                    encoder_hidden_states=data.controlnet_prompt_embeds,
-                    controlnet_cond=data.control_image,
-                    conditioning_scale=data.cond_scale,
-                    guess_mode=data.guess_mode,
-                    added_cond_kwargs=data.controlnet_added_cond_kwargs,
-                    return_dict=False,
-                )
+                    # Will always be run at least once with every guider
+                    if pipeline.guider.is_conditional or not data.guess_mode:
+                        data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
+                            data.scaled_latents,
+                            t,
+                            encoder_hidden_states=batch.prompt_embeds,
+                            controlnet_cond=data.control_image,
+                            conditioning_scale=data.cond_scale,
+                            guess_mode=data.guess_mode,
+                            added_cond_kwargs=batch.controlnet_added_cond_kwargs,
+                            return_dict=False,
+                        )
+
+                    batch.down_block_res_samples = data.down_block_res_samples
+                    batch.mid_block_res_sample = data.mid_block_res_sample
+
+                    if pipeline.guider.is_unconditional and data.guess_mode:
+                        batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples]
+                        batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample)
+
+                    # Prepare for inpainting
+                    if data.num_channels_unet == 9:
+                        data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1)
 
-                # when we apply guidance for unet, but not for controlnet:
-                # add 0 to the unconditional batch
-                data.down_block_res_samples = pipeline.guider.prepare_input(
-                    data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples]
-                )
-                data.mid_block_res_sample = pipeline.guider.prepare_input(
-                    data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample)
-                )
+                    batch.noise_pred = pipeline.unet(
+                        data.scaled_latents,
+                        t,
+                        encoder_hidden_states=batch.prompt_embeds,
+                        timestep_cond=data.timestep_cond,
+                        cross_attention_kwargs=data.cross_attention_kwargs,
+                        added_cond_kwargs=batch.added_cond_kwargs,
+                        down_block_additional_residuals=batch.down_block_res_samples,
+                        mid_block_additional_residual=batch.mid_block_res_sample,
+                        return_dict=False,
+                    )[0]
+                    pipeline.guider.cleanup_models(pipeline.unet)
+
+                # Perform guidance
+                data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data)
 
-                data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)
-                if data.num_channels_unet == 9:
-                    data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1)
-
-                data.noise_pred = pipeline.unet(
-                    data.latent_model_input,
-                    t,
-                    encoder_hidden_states=data.prompt_embeds,
-                    timestep_cond=data.timestep_cond,
-                    cross_attention_kwargs=data.cross_attention_kwargs,
-                    added_cond_kwargs=data.added_cond_kwargs,
-                    down_block_additional_residuals=data.down_block_res_samples,
-                    mid_block_additional_residual=data.mid_block_res_sample,
-                    return_dict=False,
-                )[0]
-                # perform guidance
-                data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents)
-                # compute the previous noisy sample x_t -> x_t-1
+                # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
-                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0]
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         data.latents = data.latents.to(data.latents_dtype)
 
                 if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None:
                     data.init_latents_proper = data.image_latents
@@ -2774,9 +2719,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
 
             if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                 progress_bar.update()
 
-        pipeline.guider.reset_guider(pipeline)
-        pipeline.controlnet_guider.reset_guider(pipeline)
-
         self.add_block_state(state, data)

@@ -2792,8 +2734,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
|
||||
ComponentSpec("unet", UNet2DConditionModel),
|
||||
ComponentSpec("controlnet", ControlNetUnionModel),
|
||||
ComponentSpec("scheduler", KarrasDiffusionSchedulers),
|
||||
ComponentSpec("guider", CFGGuider, obj=CFGGuider()),
|
||||
ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()),
|
||||
ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
|
||||
ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
|
||||
]
|
||||
|
||||
@@ -2810,12 +2751,9 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("guidance_scale", default=5.0),
|
||||
InputParam("guidance_rescale", default=0.0),
|
||||
InputParam("cross_attention_kwargs"),
|
||||
InputParam("generator"),
|
||||
InputParam("eta", default=0.0),
|
||||
InputParam("guider_kwargs")
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -3008,7 +2946,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
data.height = data.height * pipeline.vae_scale_factor
data.width = data.width * pipeline.vae_scale_factor

controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet
controlnet = unwrap_module(pipeline.controlnet)

# (1.1)
# control guidance

@@ -3058,7 +2996,6 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
        crops_coords=data.crops_coords,
    )
    data.height, data.width = data.control_image[idx].shape[-2:]


# (1.6)
# controlnet_keep
@@ -3072,80 +3009,32 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
# (2) Prepare conditional inputs for unet using the guider
# adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
data.guider_kwargs = data.guider_kwargs or {}
data.guider_kwargs = {
    **data.guider_kwargs,
    "disable_guidance": data.disable_guidance,
    "guidance_scale": data.guidance_scale,
    "guidance_rescale": data.guidance_rescale,
    "batch_size": data.batch_size * data.num_images_per_prompt,
}
pipeline.guider.set_guider(pipeline, data.guider_kwargs)
data.prompt_embeds = pipeline.guider.prepare_input(
    data.prompt_embeds,
    data.negative_prompt_embeds,
)
data.add_time_ids = pipeline.guider.prepare_input(
    data.add_time_ids,
    data.negative_add_time_ids,
)
data.pooled_prompt_embeds = pipeline.guider.prepare_input(
    data.pooled_prompt_embeds,
    data.negative_pooled_prompt_embeds,
)
if data.disable_guidance:
    pipeline.guider.disable()
else:
    pipeline.guider.enable()

if data.num_channels_unet == 9:
    data.mask = pipeline.guider.prepare_input(data.mask, data.mask)
    data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents)

data.added_cond_kwargs = {
    "text_embeds": data.pooled_prompt_embeds,
    "time_ids": data.add_time_ids,
}

if data.ip_adapter_embeds is not None:
    data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
    data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds

# (3) Prepare conditional inputs for controlnet using the guider
data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
data.controlnet_guider_kwargs = data.guider_kwargs or {}
data.controlnet_guider_kwargs = {
    **data.controlnet_guider_kwargs,
    "disable_guidance": data.controlnet_disable_guidance,
    "guidance_scale": data.guidance_scale,
    "guidance_rescale": data.guidance_rescale,
    "batch_size": data.batch_size * data.num_images_per_prompt,
}
pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs)
data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
data.controlnet_added_cond_kwargs = {
    "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
    "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
}
for idx, _ in enumerate(data.control_image):
    data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx])

data.control_type = (
    data.control_type.reshape(1, -1)
    .to(data.device, dtype=data.prompt_embeds.dtype)
)
data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype)
repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0]
data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0)
data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type)

# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)

pipeline.guider.set_input_fields(
    prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
    add_time_ids=("add_time_ids", "negative_add_time_ids"),
    pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
    ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"),
)

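Each keyword here maps a guider-facing field name to a (conditional, unconditional) pair of attributes on the block state, so prepare_inputs can hand every model pass the right tensors. A rough, hypothetical illustration of that mapping (the real logic presumably lives in the guider base class):

def build_batches(data, input_fields):
    # Hypothetical illustration only: build one input view per model pass
    # from (conditional_field, unconditional_field) attribute pairs.
    batches = []
    for use_cond in (True, False):
        batches.append({
            name: getattr(data, cond if use_cond else uncond)
            for name, (cond, uncond) in input_fields.items()
        })
    return batches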
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
    for i, t in enumerate(data.timesteps):
        # prepare latents for unet using the guider
        data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
        pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
        guider_data = pipeline.guider.prepare_inputs(data)

        # prepare latents for controlnet using the guider
        data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
        data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t)

        if isinstance(data.controlnet_keep[i], list):
            data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])]
@@ -3154,49 +3043,69 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
        if isinstance(data.controlnet_cond_scale, list):
            data.controlnet_cond_scale = data.controlnet_cond_scale[0]
        data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i]

        for batch in guider_data:
            pipeline.guider.prepare_models(pipeline.unet)

            # Prepare additional conditionings
            batch.added_cond_kwargs = {
                "text_embeds": batch.pooled_prompt_embeds,
                "time_ids": batch.add_time_ids,
            }
            if batch.ip_adapter_embeds is not None:
                batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds

            # Prepare controlnet additional conditionings
            batch.controlnet_added_cond_kwargs = {
                "text_embeds": batch.pooled_prompt_embeds,
                "time_ids": batch.add_time_ids,
            }

        data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
            pipeline.scheduler.scale_model_input(data.control_model_input, t),
            t,
            encoder_hidden_states=data.controlnet_prompt_embeds,
            controlnet_cond=data.control_image,
            control_type=data.control_type,
            control_type_idx=data.control_mode,
            conditioning_scale=data.cond_scale,
            guess_mode=data.guess_mode,
            added_cond_kwargs=data.controlnet_added_cond_kwargs,
            return_dict=False,
        )
            # Will always be run at least once with every guider
            if pipeline.guider.is_conditional or not data.guess_mode:
                data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
                    data.scaled_latents,
                    t,
                    encoder_hidden_states=batch.prompt_embeds,
                    controlnet_cond=data.control_image,
                    control_type=data.control_type,
                    control_type_idx=data.control_mode,
                    conditioning_scale=data.cond_scale,
                    guess_mode=data.guess_mode,
                    added_cond_kwargs=batch.controlnet_added_cond_kwargs,
                    return_dict=False,
                )

                batch.down_block_res_samples = data.down_block_res_samples
                batch.mid_block_res_sample = data.mid_block_res_sample

            if pipeline.guider.is_unconditional and data.guess_mode:
                batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples]
                batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample)

        # when we apply guidance for unet, but not for controlnet:
        # add 0 to the unconditional batch
        data.down_block_res_samples = pipeline.guider.prepare_input(
            data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples]
        )
        data.mid_block_res_sample = pipeline.guider.prepare_input(
            data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample)
        )
            if data.num_channels_unet == 9:
                data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1)

        data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)
        if data.num_channels_unet == 9:
            data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1)
            batch.noise_pred = pipeline.unet(
                data.scaled_latents,
                t,
                encoder_hidden_states=batch.prompt_embeds,
                timestep_cond=data.timestep_cond,
                cross_attention_kwargs=data.cross_attention_kwargs,
                added_cond_kwargs=batch.added_cond_kwargs,
                down_block_additional_residuals=batch.down_block_res_samples,
                mid_block_additional_residual=batch.mid_block_res_sample,
                return_dict=False,
            )[0]
            pipeline.guider.cleanup_models(pipeline.unet)

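The guess-mode branches above deserve a note: in guess mode the ControlNet runs only for the conditional pass, and the unconditional pass receives zeroed residuals so it stays hint-free. A small, hypothetical distillation of that rule:

import torch

def residuals_for_pass(is_conditional, guess_mode, residuals):
    # Hypothetical helper mirroring the branch above: in guess mode the
    # unconditional pass must not see ControlNet hints, so it gets zeros.
    if guess_mode and not is_conditional:
        return [torch.zeros_like(r) for r in residuals]
    return residuals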
        # Perform guidance
        data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data)

        data.noise_pred = pipeline.unet(
            data.latent_model_input,
            t,
            encoder_hidden_states=data.prompt_embeds,
            timestep_cond=data.timestep_cond,
            cross_attention_kwargs=data.cross_attention_kwargs,
            added_cond_kwargs=data.added_cond_kwargs,
            down_block_additional_residuals=data.down_block_res_samples,
            mid_block_additional_residual=data.mid_block_res_sample,
            return_dict=False,
        )[0]
        # perform guidance
        data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents)
        # compute the previous noisy sample x_t -> x_t-1
        # Perform scheduler step using the predicted output
        data.latents_dtype = data.latents.dtype
        data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
        data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0]

        if data.latents.dtype != data.latents_dtype:
            if torch.backends.mps.is_available():
                # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
@@ -3209,14 +3118,10 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
            data.init_latents_proper = pipeline.scheduler.add_noise(
                data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep])
            )

            data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents

        if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
            progress_bar.update()

pipeline.guider.reset_guider(pipeline)
pipeline.controlnet_guider.reset_guider(pipeline)

self.add_block_state(state, data)

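Taken together, the denoise step now follows a fixed per-timestep protocol: set_state, prepare_inputs, one model pass per condition bracketed by prepare_models/cleanup_models, then a single guider call that folds the passes into one prediction. A runnable toy of that protocol (every class and tensor here is a stand-in, not the real API):

import torch

class ToyGuider:
    # Stand-in mimicking the call pattern above, not the real guider API.
    def set_state(self, step, num_inference_steps, timestep):
        self.step = step

    def prepare_inputs(self, data):
        return [data["cond"], data["uncond"]]   # one entry per model pass

    def prepare_models(self, model):
        pass                                    # hook point for model patching

    def cleanup_models(self, model):
        pass                                    # undo any patching

    def __call__(self, preds, scale=5.0):
        cond, uncond = preds
        return uncond + scale * (cond - uncond), {}

unet = lambda x: 0.9 * x                        # stand-in denoiser
guider = ToyGuider()
data = {"cond": torch.randn(1, 4), "uncond": torch.randn(1, 4)}

guider.set_state(step=0, num_inference_steps=30, timestep=999)
preds = []
for batch in guider.prepare_inputs(data):
    guider.prepare_models(unet)
    preds.append(unet(batch))
    guider.cleanup_models(unet)
noise_pred, scheduler_step_kwargs = guider(preds)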
@@ -3543,6 +3448,11 @@ class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks):
    "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \
    "- for text-to-image generation, all you need to provide is `prompt`"

# TODO(yiyi, aryan): We need another step before the text encoder to set the `num_inference_steps` attribute for the guider so that
# things like when to apply guidance and how many conditions to prepare can be determined. Currently, this is done by
# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of the
# guider's configuration.

# block mapping
TEXT2IMAGE_BLOCKS = OrderedDict([
    ("text_encoder", StableDiffusionXLTextEncoderStep),
@@ -3664,7 +3574,6 @@ SDXL_INPUTS_SCHEMA = {
    "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
    "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
    "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
    "guidance_scale": InputParam("guidance_scale", type_hint=float, default=5.0, description="Classifier-Free Diffusion Guidance scale"),
    "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in the CLIP text encoder"),
    "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
    "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting; white pixels will be repainted"),
@@ -3689,9 +3598,7 @@ SDXL_INPUTS_SCHEMA = {
    "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
    "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates the aesthetic score of the generated image"),
    "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates a negative aesthetic score"),
    "guidance_rescale": InputParam("guidance_rescale", type_hint=float, default=0.0, description="Guidance rescale factor to fix overexposure"),
    "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
    "guider_kwargs": InputParam("guider_kwargs", type_hint=Optional[Dict[str, Any]], description="Kwargs dictionary passed to the Guider"),
    "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
    "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"),
    "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
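For reference, the overexposure fix behind `guidance_rescale` is the variance-rescaling trick from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/abs/2305.08891); a sketch of the usual implementation, as commonly written in diffusers pipelines:

import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Rescale the guided prediction so its std matches the conditional
    # prediction's std, then blend the two by guidance_rescale.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg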
@@ -3757,4 +3664,4 @@ SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {

SDXL_OUTPUTS_SCHEMA = {
    "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images")
}
}

@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


def unwrap_module(module):
    """Unwraps a module if it was compiled with torch.compile()"""
    return module._orig_mod if is_compiled_module(module) else module

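`unwrap_module` centralizes the `_orig_mod` unwrapping that the ControlNet step above used to inline; a quick usage sketch:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)
assert unwrap_module(compiled) is model    # compiled modules are unwrapped
assert unwrap_module(model) is model       # plain modules pass through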

def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
    """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).