# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
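
# This module implements "layer skip" hooks: `ModelHook`s that bypass the attention
# sublayer, the feed-forward sublayer, the attention score computation, or an entire
# transformer block at runtime. The public entry point is `apply_layer_skip` at the
# bottom of this file.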

import math
from dataclasses import dataclass
from typing import Callable, List, Optional

import torch

from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import (
    _ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
    _ATTENTION_CLASSES,
    _FEEDFORWARD_CLASSES,
    _get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook

logger = get_logger(__name__)  # pylint: disable=invalid-name

_LAYER_SKIP_HOOK = "layer_skip_hook"


@dataclass
class LayerSkipConfig:
    r"""
    Configuration for skipping internal transformer blocks when executing a transformer model.

    Args:
        indices (`List[int]`):
            The indices of the blocks to skip within the stack of transformer blocks identified by `fqn`.
        fqn (`str`, defaults to `"auto"`):
            The fully qualified name identifying the stack of transformer blocks. Typically, this is
            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
            For automatic detection, set this to `"auto"`. Note that automatic detection only works for DiT-style
            models; for UNet models, the correct `fqn` must be provided explicitly.
        skip_attention (`bool`, defaults to `True`):
            Whether to skip attention blocks.
        skip_attention_scores (`bool`, defaults to `False`):
            Whether to skip attention score computation in the attention blocks. This is equivalent to using the
            `value` projections as the output of scaled dot product attention.
        skip_ff (`bool`, defaults to `True`):
            Whether to skip feed-forward blocks.
        dropout (`float`, defaults to `1.0`):
            The dropout probability for dropping the outputs of the skipped layers. By default, this is set to
            `1.0`, meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the
            outputs of the skipped layers are fully retained, which is equivalent to not skipping any layers.
    """

    indices: List[int]
    fqn: str = "auto"
    skip_attention: bool = True
    skip_attention_scores: bool = False
    skip_ff: bool = True
    dropout: float = 1.0

    def __post_init__(self):
        if not (0 <= self.dropout <= 1):
            raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
        if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
            raise ValueError(
                "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
            )
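

# Usage sketch for `LayerSkipConfig` (hypothetical layout, assuming the blocks live under
# `transformer_blocks`): skip blocks 10 and 20 entirely, attention and feed-forward alike.
#
#   config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
#
# With `dropout=0.5` instead, the skipped blocks would still run, but half of their output
# elements would be randomly zeroed (and the rest rescaled, as `torch.nn.functional.dropout`
# does), giving a stochastic middle ground between skipping and not skipping.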


class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.nn.functional.scaled_dot_product_attention:
            # Short-circuit SDPA: return the `value` tensor directly, i.e. behave as if
            # the attention weights were the identity matrix.
            value = kwargs.get("value", None)
            if value is None:
                value = args[2]
            return value
        return func(*args, **kwargs)
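

# Quick sanity check for `AttentionScoreSkipFunctionMode` (illustrative only, not part of
# the library API): under the mode, SDPA is intercepted and `value` is returned untouched.
#
#   q = k = v = torch.randn(1, 4, 16, 8)
#   with AttentionScoreSkipFunctionMode():
#       out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
#   assert out is v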


class AttentionProcessorSkipHook(ModelHook):
    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
        super().__init__()
        self.skip_processor_output_fn = skip_processor_output_fn
        self.skip_attention_scores = skip_attention_scores
        self.dropout = dropout

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if self.skip_attention_scores:
            if not math.isclose(self.dropout, 1.0):
                raise ValueError(
                    "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
                )
            # Run the regular forward pass, but intercept SDPA so attention scores are never computed.
            with AttentionScoreSkipFunctionMode():
                output = self.fn_ref.original_forward(*args, **kwargs)
        else:
            if math.isclose(self.dropout, 1.0):
                # Full skip: replace the attention output with the processor's passthrough output.
                output = self.skip_processor_output_fn(module, *args, **kwargs)
            else:
                # Partial skip: run the attention block, then randomly drop output elements.
                output = self.fn_ref.original_forward(*args, **kwargs)
                output = torch.nn.functional.dropout(output, p=self.dropout)
        return output
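

# Note on `AttentionProcessorSkipHook`: `skip_processor_output_fn` is the per-processor
# passthrough registered in `AttentionProcessorRegistry` (looked up in
# `_apply_layer_skip_hook` below), typically returning the attention input unchanged.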


class FeedForwardSkipHook(ModelHook):
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            # Full skip: return the input hidden states untouched. Feed-forward
            # implementations name their input differently, so try the common
            # conventions: `hidden_states`, then `x`, then the first positional argument.
            output = kwargs.get("hidden_states", None)
            if output is None:
                output = kwargs.get("x", None)
            if output is None and len(args) > 0:
                output = args[0]
        else:
            # Partial skip: run the feed-forward block, then randomly drop output elements.
            output = self.fn_ref.original_forward(*args, **kwargs)
            output = torch.nn.functional.dropout(output, p=self.dropout)
        return output


class TransformerBlockSkipHook(ModelHook):
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def initialize_hook(self, module):
        # Resolve the metadata (including the skip/passthrough function) registered for
        # this block class.
        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            # Full skip: return the block's inputs via the registered passthrough function.
            output = self._metadata.skip_block_output_fn(module, *args, **kwargs)
        else:
            # Partial skip: run the block, then randomly drop output elements.
            output = self.fn_ref.original_forward(*args, **kwargs)
            output = torch.nn.functional.dropout(output, p=self.dropout)
        return output
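

# Note: `TransformerBlockSkipHook` requires the wrapped block's class to be registered in
# `TransformerBlockRegistry` (see `initialize_hook` above); the registered
# `skip_block_output_fn` is what produces the passthrough output when the block is skipped.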


def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
    r"""
    Apply layer skipping to internal layers of a transformer.

    Args:
        module (`torch.nn.Module`):
            The transformer model to which the layer skip hook should be applied.
        config (`LayerSkipConfig`):
            The configuration for the layer skip hook.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXTransformer3DModel
    >>> from diffusers.hooks import LayerSkipConfig, apply_layer_skip

    >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
    ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
    ... )
    >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
    >>> apply_layer_skip(transformer, config)
    ```
    """
    _apply_layer_skip_hook(module, config)


def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
    name = name or _LAYER_SKIP_HOOK

    if config.skip_attention and config.skip_attention_scores:
        raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
    if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
        raise ValueError(
            "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
        )

    if config.fqn == "auto":
        # Try the common identifiers under which DiT-style models store their blocks.
        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
            if hasattr(module, identifier):
                config.fqn = identifier
                break
        else:
            raise ValueError(
                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a "
                "valid `fqn` (fully qualified name) that identifies a stack of transformer blocks."
            )

    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
    if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
        raise ValueError(
            f"Could not find {config.fqn} in the provided module, or the configured `fqn` (fully qualified name) "
            "does not identify a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of "
            "transformer blocks."
        )
    if len(config.indices) == 0:
        raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")

    blocks_found = False
    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue

        blocks_found = True

        if config.skip_attention and config.skip_ff:
            # Skip the whole block with a single hook.
            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
            registry = HookRegistry.check_if_exists_or_initialize(block)
            hook = TransformerBlockSkipHook(config.dropout)
            registry.register_hook(hook, name)

        elif config.skip_attention or config.skip_attention_scores:
            # Hook only the self-attention submodules; cross-attention is left untouched.
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
                    logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
                    registry.register_hook(hook, name)

        if config.skip_ff:
            # Hook the feed-forward submodules.
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _FEEDFORWARD_CLASSES):
                    logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = FeedForwardSkipHook(config.dropout)
                    registry.register_hook(hook, name)

    if not blocks_found:
        raise ValueError(
            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
        )
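

# End-to-end usage sketch (hypothetical: the checkpoint id and block layout are
# assumptions, and skipping only the attention scores in blocks 10 and 20 is just one
# possible configuration):
#
#   import torch
#   from diffusers import CogVideoXTransformer3DModel
#
#   transformer = CogVideoXTransformer3DModel.from_pretrained(
#       "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
#   )
#   config = LayerSkipConfig(
#       indices=[10, 20],
#       fqn="transformer_blocks",
#       skip_attention=False,
#       skip_attention_scores=True,
#       skip_ff=False,
#   )
#   apply_layer_skip(transformer, config)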