mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

skip layer guidance

This commit is contained in:
Aryan
2025-04-03 03:26:55 +02:00
parent 594e8d663f
commit 5ac7f360b0
9 changed files with 407 additions and 42 deletions

View File: src/diffusers/__init__.py

@@ -130,15 +130,17 @@ except OptionalDependencyNotAvailable:
    _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
-   _import_structure["guiders"].extend(["ClassifierFreeGuidance"])
+   _import_structure["guiders"].extend(["ClassifierFreeGuidance", "SkipLayerGuidance"])
    _import_structure["hooks"].extend(
        [
            "FasterCacheConfig",
            "FirstBlockCacheConfig",
            "HookRegistry",
+           "LayerSkipConfig",
            "PyramidAttentionBroadcastConfig",
            "apply_faster_cache",
            "apply_first_block_cache",
+           "apply_layer_skip",
            "apply_pyramid_attention_broadcast",
        ]
    )
@@ -712,14 +714,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from .utils.dummy_pt_objects import *  # noqa F403
    else:
-       from .guiders import ClassifierFreeGuidance
+       from .guiders import ClassifierFreeGuidance, SkipLayerGuidance
        from .hooks import (
            FasterCacheConfig,
            FirstBlockCacheConfig,
            HookRegistry,
+           LayerSkipConfig,
            PyramidAttentionBroadcastConfig,
            apply_faster_cache,
            apply_first_block_cache,
+           apply_layer_skip,
            apply_pyramid_attention_broadcast,
        )
        from .models import (

View File: src/diffusers/guiders/__init__.py

@@ -18,3 +18,4 @@ from ..utils import is_torch_available
if is_torch_available():
    from .classifier_free_guidance import ClassifierFreeGuidance
    from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
+   from .skip_layer_guidance import SkipLayerGuidance

View File: src/diffusers/guiders/classifier_free_guidance.py

@@ -43,11 +43,11 @@ class ClassifierFreeGuidance(GuidanceMixin):
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
-       scale (`float`, defaults to `7.5`):
+       guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
-       rescale (`float`, defaults to `0.0`):
+       guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
@@ -56,31 +56,26 @@ class ClassifierFreeGuidance(GuidanceMixin):
            we use the diffusers-native implementation that has been in the codebase for a long time.
    """

-   def __init__(self, scale: float = 7.5, rescale: float = 0.0, use_original_formulation: bool = False):
-       self.scale = scale
-       self.rescale = rescale
+   def __init__(
+       self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
+   ):
+       self.guidance_scale = guidance_scale
+       self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
-       if math.isclose(self.scale, 1.0):
+       if math.isclose(self.guidance_scale, 1.0):
            return pred_cond

        shift = pred_cond - pred_uncond
        pred = pred_cond if self.use_original_formulation else pred_uncond
-       pred = pred + self.scale * shift
-       if self.rescale > 0.0:
-           pred = rescale_noise_cfg(pred, pred_cond, self.rescale)
+       pred = pred + self.guidance_scale * shift
+       if self.guidance_rescale > 0.0:
+           pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred

    @property
    def num_conditions(self) -> int:
-       if math.isclose(self.scale, 1.0):
-           return 1
-       return 2
-
-   @property
-   def guidance_scale(self) -> float:
-       return self.scale
-
-   @property
-   def guidance_rescale(self) -> float:
-       return self.rescale
+       num_conditions = 1
+       if not math.isclose(self.guidance_scale, 1.0):
+           num_conditions += 1
+       return num_conditions
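
The rename from `scale`/`rescale` to `guidance_scale`/`guidance_rescale` lets the attributes carry their public names directly, so the forwarding properties are dropped. A minimal usage sketch of the updated guider, assuming `GuidanceMixin.__call__` forwards its arguments to `forward` (the tensor shapes are illustrative):

```python
import torch
from diffusers import ClassifierFreeGuidance

guidance = ClassifierFreeGuidance(guidance_scale=7.5, guidance_rescale=0.0)
assert guidance.num_conditions == 2  # guidance_scale != 1.0, so two model passes are needed

pred_cond = torch.randn(1, 4, 64, 64)    # noise prediction from the conditional batch
pred_uncond = torch.randn(1, 4, 64, 64)  # noise prediction from the unconditional batch

# Diffusers-native formulation: pred_uncond + guidance_scale * (pred_cond - pred_uncond)
noise_pred = guidance(pred_cond, pred_uncond)
```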

View File: src/diffusers/guiders/guider_utils.py

@@ -25,6 +25,19 @@ logger = get_logger(__name__)  # pylint: disable=invalid-name

class GuidanceMixin:
    r"""Base mixin class providing the skeleton for implementing guidance techniques."""

+   def __init__(self):
+       self._step: int = None
+       self._num_inference_steps: int = None
+       self._timestep: torch.LongTensor = None
+
+   def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
+       self._step = step
+       self._num_inference_steps = num_inference_steps
+       self._timestep = timestep
+
+   def prepare_models(self, denoiser: torch.nn.Module) -> None:
+       pass
+
    def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
        num_conditions = self.num_conditions
        list_of_inputs = []
@@ -32,16 +32,27 @@ class GuidanceMixin:
            if isinstance(arg, torch.Tensor):
                list_of_inputs.append([arg] * num_conditions)
            elif isinstance(arg, (tuple, list)):
-               inputs = [x for x in arg if x is not None]
-               if len(inputs) < num_conditions:
-                   raise ValueError(f"Required at least {num_conditions} inputs, but got {len(inputs)}.")
-               list_of_inputs.append(inputs[:num_conditions])
+               if len(arg) != 2:
+                   raise ValueError(
+                       f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
+                       f"with the first element being the conditional input and the second element being the unconditional input or None."
+                   )
+               if arg[1] is None:
+                   # Only conditioning inputs for all batches
+                   list_of_inputs.append([arg[0]] * num_conditions)
+               else:
+                   # Alternating conditional and unconditional inputs as batches
+                   inputs = [arg[i % 2] for i in range(num_conditions)]
+                   list_of_inputs.append(inputs)
            else:
                raise ValueError(
                    f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
                )
        return tuple(list_of_inputs)

+   def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+       pass
+
    def __call__(self, *args) -> Any:
        if len(args) != self.num_conditions:
            raise ValueError(
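
The new `prepare_inputs` contract is stricter than the old "filter out `None`s" behavior: every non-tensor argument must be a `(conditional, unconditional_or_None)` pair, and the returned lists alternate the two entries across `num_conditions` batches. A small sketch of the expected behavior, assuming a plain CFG guider with two conditions:

```python
import torch
from diffusers import ClassifierFreeGuidance

guidance = ClassifierFreeGuidance(guidance_scale=5.0)  # num_conditions == 2

latents = torch.randn(1, 4, 64, 64)
prompt_embeds = torch.randn(1, 77, 768)
negative_prompt_embeds = torch.randn(1, 77, 768)

latents_list, embeds_list = guidance.prepare_inputs(latents, (prompt_embeds, negative_prompt_embeds))
# Plain tensors are replicated once per condition; pairs alternate cond/uncond.
assert len(latents_list) == 2
assert embeds_list[0] is prompt_embeds and embeds_list[1] is negative_prompt_embeds
# Passing (prompt_embeds, None) would replicate the conditional input instead.
```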

View File: src/diffusers/guiders/skip_layer_guidance.py

@@ -0,0 +1,195 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import torch

from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import GuidanceMixin, rescale_noise_cfg


class SkipLayerGuidance(GuidanceMixin):
    """
    Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5

    SLG extends classifier-free guidance with an additional guidance term that contrasts the conditional prediction
    with one obtained while selected transformer blocks of the denoiser are skipped, active only within a
    configurable window of the denoising schedule.

    CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
    jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
    inference. This allows the model to trade off between generation quality and sample diversity.

    The original paper proposes scaling and shifting the conditional distribution based on the difference between
    conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]

    Diffusers implemented the scaling and shifting on the unconditional prediction instead, which is equivalent to what
    the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

    The intuition behind the original formulation can be thought of as moving the conditional distribution estimates
    further away from the unconditional distribution estimates, while the diffusers-native implementation can be
    thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
    the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        skip_layer_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for skip-layer guidance, applied to the difference between the conditional prediction
            and the prediction obtained with the configured layers skipped.
        skip_layer_guidance_start (`float`, defaults to `0.01`):
            The fraction of the total number of inference steps after which skip-layer guidance starts being applied.
        skip_layer_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of inference steps at which skip-layer guidance stops being applied.
        skip_guidance_layers (`int` or `List[int]`, *optional*):
            The indices of the transformer blocks to skip, converted internally to `LayerSkipConfig`s with
            `fqn="auto"`. Mutually exclusive with `skip_layer_config`.
        skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration(s) describing which layers to skip. Mutually exclusive with `skip_guidance_layers`.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time.
    """

    def __init__(
        self,
        guidance_scale: float = 7.5,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_start: float = 0.01,
        skip_layer_guidance_stop: float = 0.2,
        skip_guidance_layers: Optional[Union[int, List[int]]] = None,
        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
    ):
        self.guidance_scale = guidance_scale
        self.skip_layer_guidance_scale = skip_layer_guidance_scale
        self.skip_layer_guidance_start = skip_layer_guidance_start
        self.skip_layer_guidance_stop = skip_layer_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if skip_guidance_layers is None and skip_layer_config is None:
            raise ValueError(
                "Either `skip_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
            )
        if skip_guidance_layers is not None and skip_layer_config is not None:
            raise ValueError("Only one of `skip_guidance_layers` or `skip_layer_config` can be provided.")

        if skip_guidance_layers is not None:
            if isinstance(skip_guidance_layers, int):
                skip_guidance_layers = [skip_guidance_layers]
            if not isinstance(skip_guidance_layers, list):
                raise ValueError(
                    f"Expected `skip_guidance_layers` to be an int or a list of ints, but got {type(skip_guidance_layers)}."
                )
            skip_layer_config = [LayerSkipConfig(indices=[layer], fqn="auto") for layer in skip_guidance_layers]

        if isinstance(skip_layer_config, LayerSkipConfig):
            skip_layer_config = [skip_layer_config]
        if not isinstance(skip_layer_config, list):
            raise ValueError(
                f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
            )

        self.skip_layer_config = skip_layer_config
        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

    def prepare_models(self, denoiser: torch.nn.Module):
        skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
        skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
        # Register the hooks for layer skipping if the step is within the specified range
        if skip_start_step < self._step < skip_stop_step:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
        num_conditions = self.num_conditions
        list_of_inputs = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                list_of_inputs.append([arg] * num_conditions)
            elif isinstance(arg, (tuple, list)):
                if len(arg) != 2:
                    raise ValueError(
                        f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
                        f"with the first element being the conditional input and the second element being the unconditional input or None."
                    )
                if arg[1] is None:
                    # Only conditioning inputs for all batches
                    list_of_inputs.append([arg[0]] * num_conditions)
                else:
                    list_of_inputs.append([arg[0], arg[1], arg[0]])
            else:
                raise ValueError(
                    f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
                )
        return tuple(list_of_inputs)

    def cleanup_models(self, denoiser: torch.nn.Module):
        registry = HookRegistry.check_if_exists_or_initialize(denoiser)
        # Remove the hooks after inference
        for hook_name in self._skip_layer_hook_names:
            registry.remove_hook(hook_name, recurse=True)

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_skip: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        pred = None
        skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
        skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)

        if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0):
            pred = pred_cond
        elif math.isclose(self.guidance_scale, 1.0):
            if skip_start_step < self._step < skip_stop_step:
                shift = pred_cond - pred_cond_skip
                pred = pred_cond if self.use_original_formulation else pred_cond_skip
                pred = pred + self.skip_layer_guidance_scale * shift
            else:
                pred = pred_cond
        elif math.isclose(self.skip_layer_guidance_scale, 1.0):
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
            if skip_start_step < self._step < skip_stop_step:
                shift_skip = pred_cond - pred_cond_skip
                pred = pred + self.skip_layer_guidance_scale * shift_skip

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
        skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)

        if not math.isclose(self.guidance_scale, 1.0):
            num_conditions += 1
        if not math.isclose(self.skip_layer_guidance_scale, 1.0) and skip_start_step < self._step < skip_stop_step:
            num_conditions += 1

        return num_conditions
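
The three-way `forward` above composes the two guidance terms: with the default formulation and both scales active, `pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + skip_layer_guidance_scale * (pred_cond - pred_cond_skip)` while the step lies in the active window. A construction sketch (the layer indices and scales here are illustrative, not tuned recommendations):

```python
from diffusers import SkipLayerGuidance

# Apply SLG between 1% and 20% of the denoising schedule, skipping blocks 7-9.
guidance = SkipLayerGuidance(
    guidance_scale=4.0,
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,
    skip_layer_guidance_stop=0.2,
    skip_guidance_layers=[7, 8, 9],  # converted internally to LayerSkipConfigs with fqn="auto"
)
# Inside the active window the guider requests three predictions per step:
# conditional, unconditional, and conditional-with-skipped-layers.
```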

View File: src/diffusers/hooks/__init__.py

@@ -20,5 +20,6 @@ if is_torch_available():
    from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
    from .group_offloading import apply_group_offloading
    from .hooks import HookRegistry, ModelHook
+   from .layer_skip import LayerSkipConfig, apply_layer_skip
    from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
    from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

View File: src/diffusers/hooks/layer_skip.py

@@ -0,0 +1,110 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional

import torch

from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_LAYER_SKIP_HOOK = "layer_skip_hook"


@dataclass
class LayerSkipConfig:
    r"""
    Configuration for skipping internal transformer blocks when executing a transformer model.

    Args:
        indices (`List[int]`):
            The indices of the transformer blocks to skip within the stack identified by `fqn`.
        fqn (`str`, defaults to `"auto"`):
            The fully qualified name identifying the stack of transformer blocks. Typically, this is
            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
    """

    indices: List[int]
    fqn: str = "auto"


class LayerSkipHook(ModelHook):
    def initialize_hook(self, module):
        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        # Instead of running the block, return the registered "skipped" output for this block type.
        return self._metadata.skip_block_output_fn(module, *args, **kwargs)


def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
    r"""
    Apply layer skipping to internal layers of a transformer.

    Args:
        module (`torch.nn.Module`):
            The transformer model to which the layer skip hook should be applied.
        config (`LayerSkipConfig`):
            The configuration for the layer skip hook.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXTransformer3DModel, LayerSkipConfig, apply_layer_skip

    >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
    ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
    ... )
    >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
    >>> apply_layer_skip(transformer, config)
    ```
    """
    _apply_layer_skip_hook(module, config)


def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
    name = name or _LAYER_SKIP_HOOK

    if config.fqn == "auto":
        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
            if hasattr(module, identifier):
                config.fqn = identifier
                break
        else:
            raise ValueError(
                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
            )

    transformer_blocks = getattr(module, config.fqn, None)
    if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
        raise ValueError(
            f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
            f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
        )
    if len(config.indices) == 0:
        raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")

    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue
        logger.debug(f"Apply LayerSkipHook to '{config.fqn}.{i}'")
        registry = HookRegistry.check_if_exists_or_initialize(block)
        hook = LayerSkipHook()
        registry.register_hook(hook, name)
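
Because `_apply_layer_skip_hook` accepts an explicit hook `name`, callers such as `SkipLayerGuidance` can register skipping for one step and remove it afterwards through the registry. A lifecycle sketch, assuming `transformer` is an already-loaded model whose block class is registered with `TransformerBlockRegistry`:

```python
from diffusers.hooks import HookRegistry, LayerSkipConfig
from diffusers.hooks.layer_skip import _apply_layer_skip_hook

# `transformer` is assumed to be a loaded model, e.g. a CogVideoXTransformer3DModel.
config = LayerSkipConfig(indices=[10], fqn="transformer_blocks")
_apply_layer_skip_hook(transformer, config, name="SkipLayerGuidance_0")

# ... run the skip-layer forward pass; block 10 is short-circuited ...

registry = HookRegistry.check_if_exists_or_initialize(transformer)
registry.remove_hook("SkipLayerGuidance_0", recurse=True)
```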

View File: src/diffusers/pipelines/cogview4/pipeline_cogview4.py

@@ -520,7 +520,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
        _raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
        if guidance is None:
-           guidance = ClassifierFreeGuidance(scale=guidance_scale)
+           guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -612,29 +612,34 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
        )
        self._num_timesteps = len(timesteps)

-       latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs(
-           latents,
-           (prompt_embeds, negative_prompt_embeds),
-           original_size,
-           target_size,
-           crops_coords_top_left,
-       )
-
        # Denoising loop
        transformer_dtype = self.transformer.dtype
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

+       conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
+       prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
+
        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
            for i, t in enumerate(timesteps):
                self._current_timestep = t
                if self.interrupt:
                    continue

+               guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+               guidance.prepare_models(self.transformer)
+               latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs(
+                   latents,
+                   (prompt_embeds[0], negative_prompt_embeds[0]),
+                   original_size[0],
+                   target_size[0],
+                   crops_coords_top_left[0],
+               )
+
                noise_preds = []
-               for i, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
+               for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
                    zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
                ):
-                   cc.mark_state(f"batch_{i}")
+                   cc.mark_state(f"batch_{batch_index}")
                    latent = latent.to(transformer_dtype)
                    timestep = t.expand(latent.shape[0])
                    noise_pred = self.transformer(
@@ -651,6 +656,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                noise_pred = guidance(*noise_preds)
                latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
+               guidance.cleanup_models(self.transformer)

                # call the callback, if provided
                if callback_on_step_end is not None:
@@ -664,10 +670,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                        callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
                    ]

-                   latents, prompt_embeds = guidance.prepare_inputs(
-                       latents, (prompt_embeds[0], negative_prompt_embeds[0])
-                   )
-
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
@@ -675,7 +677,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                    xm.mark_step()

        self._current_timestep = None
-       latents = latents[0]

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
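
The net effect is that all guidance bookkeeping moves into the loop body. A simplified pseudocode sketch of the new per-step pattern (CogView4-specific conditioning arguments omitted and the transformer call signature abbreviated; variable names are from the pipeline scope):

```python
for i, t in enumerate(timesteps):
    guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
    guidance.prepare_models(transformer)  # e.g. SkipLayerGuidance registers its layer-skip hooks here

    latents_batches, embeds_batches = guidance.prepare_inputs(
        latents, (prompt_embeds, negative_prompt_embeds)
    )
    # One denoiser pass per condition requested by the guider.
    noise_preds = [
        transformer(latent, embeds, t.expand(latent.shape[0]))
        for latent, embeds in zip(latents_batches, embeds_batches)
    ]
    noise_pred = guidance(*noise_preds)

    latents = scheduler.step(noise_pred, t, latents_batches[0], return_dict=False)[0]
    guidance.cleanup_models(transformer)  # SkipLayerGuidance removes its hooks here
```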

View File: src/diffusers/utils/dummy_pt_objects.py

@@ -17,6 +17,21 @@ class ClassifierFreeGuidance(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


+class SkipLayerGuidance(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
class FasterCacheConfig(metaclass=DummyObject):
    _backends = ["torch"]

@@ -62,6 +77,21 @@ class HookRegistry(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


+class LayerSkipConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
    _backends = ["torch"]

@@ -85,6 +115,10 @@ def apply_first_block_cache(*args, **kwargs):
    requires_backends(apply_first_block_cache, ["torch"])


+def apply_layer_skip(*args, **kwargs):
+    requires_backends(apply_layer_skip, ["torch"])
+
+
def apply_pyramid_attention_broadcast(*args, **kwargs):
    requires_backends(apply_pyramid_attention_broadcast, ["torch"])