diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index cb007b1c1d..91c41bdd43 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -130,15 +130,17 @@ except OptionalDependencyNotAvailable: _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["guiders"].extend(["ClassifierFreeGuidance"]) + _import_structure["guiders"].extend(["ClassifierFreeGuidance", "SkipLayerGuidance"]) _import_structure["hooks"].extend( [ "FasterCacheConfig", "FirstBlockCacheConfig", "HookRegistry", + "LayerSkipConfig", "PyramidAttentionBroadcastConfig", "apply_faster_cache", "apply_first_block_cache", + "apply_layer_skip", "apply_pyramid_attention_broadcast", ] ) @@ -712,14 +714,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .guiders import ClassifierFreeGuidance + from .guiders import ClassifierFreeGuidance, SkipLayerGuidance from .hooks import ( FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, + LayerSkipConfig, PyramidAttentionBroadcastConfig, apply_faster_cache, apply_first_block_cache, + apply_layer_skip, apply_pyramid_attention_broadcast, ) from .models import ( diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index c56f825512..9724d30756 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -18,3 +18,4 @@ from ..utils import is_torch_available if is_torch_available(): from .classifier_free_guidance import ClassifierFreeGuidance from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning + from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 2de97291c6..18f2a2d31b 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -43,11 +43,11 @@ class ClassifierFreeGuidance(GuidanceMixin): paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. Args: - scale (`float`, defaults to `7.5`): + guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and deterioration of image quality. - rescale (`float`, defaults to `0.0`): + guidance_rescale (`float`, defaults to `0.0`): The rescale factor applied to the noise predictions. This is used to improve image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). @@ -56,31 +56,26 @@ class ClassifierFreeGuidance(GuidanceMixin): we use the diffusers-native implementation that has been in the codebase for a long time. """ - def __init__(self, scale: float = 7.5, rescale: float = 0.0, use_original_formulation: bool = False): - self.scale = scale - self.rescale = rescale + def __init__( + self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False + ): + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: - if math.isclose(self.scale, 1.0): + if math.isclose(self.guidance_scale, 1.0): return pred_cond shift = pred_cond - pred_uncond pred = pred_cond if self.use_original_formulation else pred_uncond - pred = pred + self.scale * shift - if self.rescale > 0.0: - pred = rescale_noise_cfg(pred, pred_cond, self.rescale) + pred = pred + self.guidance_scale * shift + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred @property def num_conditions(self) -> int: - if math.isclose(self.scale, 1.0): - return 1 - return 2 - - @property - def guidance_scale(self) -> float: - return self.scale - - @property - def guidance_rescale(self) -> float: - return self.rescale + num_conditions = 1 + if not math.isclose(self.guidance_scale, 1.0): + num_conditions += 1 + return num_conditions diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 54d3c51955..413a33c41c 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -25,6 +25,19 @@ logger = get_logger(__name__) # pylint: disable=invalid-name class GuidanceMixin: r"""Base mixin class providing the skeleton for implementing guidance techniques.""" + def __init__(self): + self._step: int = None + self._num_inference_steps: int = None + self._timestep: torch.LongTensor = None + + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + self._step = step + self._num_inference_steps = num_inference_steps + self._timestep = timestep + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + pass + def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: num_conditions = self.num_conditions list_of_inputs = [] @@ -32,16 +45,27 @@ class GuidanceMixin: if isinstance(arg, torch.Tensor): list_of_inputs.append([arg] * num_conditions) elif isinstance(arg, (tuple, list)): - inputs = [x for x in arg if x is not None] - if len(inputs) < num_conditions: - raise ValueError(f"Required at least {num_conditions} inputs, but got {len(inputs)}.") - list_of_inputs.append(inputs[:num_conditions]) + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + # Alternating conditional and unconditional inputs as batches + inputs = [arg[i % 2] for i in range(num_conditions)] + list_of_inputs.append(inputs) else: raise ValueError( f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." ) return tuple(list_of_inputs) + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + pass + def __call__(self, *args) -> Any: if len(args) != self.num_conditions: raise ValueError( diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py new file mode 100644 index 0000000000..677d97a47c --- /dev/null +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -0,0 +1,195 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class SkipLayerGuidance(GuidanceMixin): + """ + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to tradeoff between generation quality and sample diversity. + + The original paper proposes scaling and shifting the conditional distribution based on the difference between + conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + + Diffusers implemented the scaling and shifting on the unconditional prediction instead, which is equivalent to what + the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + + The intution behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. + """ + + def __init__( + self, + guidance_scale: float = 7.5, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, + skip_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if skip_guidance_layers is None and skip_layer_config is None: + raise ValueError( + "Either `skip_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." + ) + if skip_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_guidance_layers` or `skip_layer_config` can be provided.") + + if skip_guidance_layers is not None: + if isinstance(skip_guidance_layers, int): + skip_guidance_layers = [skip_guidance_layers] + if not isinstance(skip_guidance_layers, list): + raise ValueError( + f"Expected `skip_guidance_layers` to be an int or a list of ints, but got {type(skip_guidance_layers)}." + ) + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_guidance_layers] + + if isinstance(skip_layer_config, LayerSkipConfig): + skip_layer_config = [skip_layer_config] + + if not isinstance(skip_layer_config, list): + raise ValueError( + f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." + ) + + self.skip_layer_config = skip_layer_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + def prepare_models(self, denoiser: torch.nn.Module): + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + + # Register the hooks for layer skipping if the step is within the specified range + if skip_start_step < self._step < skip_stop_step: + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def cleanup_models(self, denoiser: torch.nn.Module): + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + + if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0): + pred = pred_cond + + elif math.isclose(self.guidance_scale, 1.0): + if skip_start_step < self._step < skip_stop_step: + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + else: + pred = pred_cond + + elif math.isclose(self.skip_layer_guidance_scale, 1.0): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if skip_start_step < self._step < skip_stop_step: + shift_skip = pred_cond - pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + + if not math.isclose(self.guidance_scale, 1.0): + num_conditions += 1 + if not math.isclose(self.skip_layer_guidance_scale, 1.0) and skip_start_step < self._step < skip_stop_step: + num_conditions += 1 + + return num_conditions diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 365bed3718..2db36d4366 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -20,5 +20,6 @@ if is_torch_available(): from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook + from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py new file mode 100644 index 0000000000..45f9365bcd --- /dev/null +++ b/src/diffusers/hooks/layer_skip.py @@ -0,0 +1,110 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_LAYER_SKIP_HOOK = "layer_skip_hook" + + +@dataclass +class LayerSkipConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + """ + + indices: List[int] + fqn: str = "auto" + + +class LayerSkipHook(ModelHook): + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + return self._metadata.skip_block_output_fn(module, *args, **kwargs) + + +def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: + r""" + Apply layer skipping to internal layers of a transformer. + + Args: + module (`torch.nn.Module`): + The transformer model to which the layer skip hook should be applied. + config (`LayerSkipConfig`): + The configuration for the layer skip hook. + + Example: + + ```python + >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks") + >>> apply_layer_skip_hook(transformer, config) + ``` + """ + _apply_layer_skip_hook(module, config) + + +def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: + name = name or _LAYER_SKIP_HOOK + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + transformer_blocks = getattr(module, config.fqn, None) + if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList): + raise ValueError( + f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify " + f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks." + ) + if len(config.indices) == 0: + raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + logger.debug(f"Apply LayerSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = LayerSkipHook() + registry.register_hook(hook, name) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 568cfd04b8..5d63b588a6 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -520,7 +520,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): _raise_guidance_deprecation_warning(guidance_scale=guidance_scale) if guidance is None: - guidance = ClassifierFreeGuidance(scale=guidance_scale) + guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -612,29 +612,34 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): ) self._num_timesteps = len(timesteps) - latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs( - latents, - (prompt_embeds, negative_prompt_embeds), - original_size, - target_size, - crops_coords_top_left, - ) - # Denoising loop transformer_dtype = self.transformer.dtype num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left] + prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds] + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): self._current_timestep = t if self.interrupt: continue + guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + guidance.prepare_models(self.transformer) + latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs( + latents, + (prompt_embeds[0], negative_prompt_embeds[0]), + original_size[0], + target_size[0], + crops_coords_top_left[0], + ) + noise_preds = [] - for i, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate( + for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate( zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left) ): - cc.mark_state(f"batch_{i}") + cc.mark_state(f"batch_{batch_index}") latent = latent.to(transformer_dtype) timestep = t.expand(latent.shape[0]) noise_pred = self.transformer( @@ -651,6 +656,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): noise_pred = guidance(*noise_preds) latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0] + guidance.cleanup_models(self.transformer) # call the callback, if provided if callback_on_step_end is not None: @@ -664,10 +670,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0]) ] - latents, prompt_embeds = guidance.prepare_inputs( - latents, (prompt_embeds[0], negative_prompt_embeds[0]) - ) - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -675,7 +677,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): xm.mark_step() self._current_timestep = None - latents = latents[0] if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 7ae9ca4c67..3c0f45461b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -17,6 +17,21 @@ class ClassifierFreeGuidance(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class SkipLayerGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class FasterCacheConfig(metaclass=DummyObject): _backends = ["torch"] @@ -62,6 +77,21 @@ class HookRegistry(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class LayerSkipConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PyramidAttentionBroadcastConfig(metaclass=DummyObject): _backends = ["torch"] @@ -85,6 +115,10 @@ def apply_first_block_cache(*args, **kwargs): requires_backends(apply_first_block_cache, ["torch"]) +def apply_layer_skip(*args, **kwargs): + requires_backends(apply_layer_skip, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"])