mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
skip layer guidance
This commit is contained in:
@@ -130,15 +130,17 @@ except OptionalDependencyNotAvailable:
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
|
||||
else:
|
||||
_import_structure["guiders"].extend(["ClassifierFreeGuidance"])
|
||||
_import_structure["guiders"].extend(["ClassifierFreeGuidance", "SkipLayerGuidance"])
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"FasterCacheConfig",
|
||||
"FirstBlockCacheConfig",
|
||||
"HookRegistry",
|
||||
"LayerSkipConfig",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
@@ -712,14 +714,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .guiders import ClassifierFreeGuidance
|
||||
from .guiders import ClassifierFreeGuidance, SkipLayerGuidance
|
||||
from .hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
LayerSkipConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_layer_skip,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
from .models import (
|
||||
|
||||
@@ -18,3 +18,4 @@ from ..utils import is_torch_available
|
||||
if is_torch_available():
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
|
||||
@@ -43,11 +43,11 @@ class ClassifierFreeGuidance(GuidanceMixin):
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
scale (`float`, defaults to `7.5`):
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
rescale (`float`, defaults to `0.0`):
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
@@ -56,31 +56,26 @@ class ClassifierFreeGuidance(GuidanceMixin):
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
"""
|
||||
|
||||
def __init__(self, scale: float = 7.5, rescale: float = 0.0, use_original_formulation: bool = False):
|
||||
self.scale = scale
|
||||
self.rescale = rescale
|
||||
def __init__(
|
||||
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
|
||||
):
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if math.isclose(self.scale, 1.0):
|
||||
if math.isclose(self.guidance_scale, 1.0):
|
||||
return pred_cond
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.scale * shift
|
||||
if self.rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.rescale)
|
||||
pred = pred + self.guidance_scale * shift
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
if math.isclose(self.scale, 1.0):
|
||||
return 1
|
||||
return 2
|
||||
|
||||
@property
|
||||
def guidance_scale(self) -> float:
|
||||
return self.scale
|
||||
|
||||
@property
|
||||
def guidance_rescale(self) -> float:
|
||||
return self.rescale
|
||||
num_conditions = 1
|
||||
if not math.isclose(self.guidance_scale, 1.0):
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
@@ -25,6 +25,19 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
class GuidanceMixin:
|
||||
r"""Base mixin class providing the skeleton for implementing guidance techniques."""
|
||||
|
||||
def __init__(self):
|
||||
self._step: int = None
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
|
||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
||||
self._step = step
|
||||
self._num_inference_steps = num_inference_steps
|
||||
self._timestep = timestep
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
num_conditions = self.num_conditions
|
||||
list_of_inputs = []
|
||||
@@ -32,16 +45,27 @@ class GuidanceMixin:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
list_of_inputs.append([arg] * num_conditions)
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
inputs = [x for x in arg if x is not None]
|
||||
if len(inputs) < num_conditions:
|
||||
raise ValueError(f"Required at least {num_conditions} inputs, but got {len(inputs)}.")
|
||||
list_of_inputs.append(inputs[:num_conditions])
|
||||
if len(arg) != 2:
|
||||
raise ValueError(
|
||||
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
|
||||
f"with the first element being the conditional input and the second element being the unconditional input or None."
|
||||
)
|
||||
if arg[1] is None:
|
||||
# Only conditioning inputs for all batches
|
||||
list_of_inputs.append([arg[0]] * num_conditions)
|
||||
else:
|
||||
# Alternating conditional and unconditional inputs as batches
|
||||
inputs = [arg[i % 2] for i in range(num_conditions)]
|
||||
list_of_inputs.append(inputs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
if len(args) != self.num_conditions:
|
||||
raise ValueError(
|
||||
|
||||
195
src/diffusers/guiders/skip_layer_guidance.py
Normal file
195
src/diffusers/guiders/skip_layer_guidance.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..hooks import HookRegistry, LayerSkipConfig
|
||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||
from .guider_utils import GuidanceMixin, rescale_noise_cfg
|
||||
|
||||
|
||||
class SkipLayerGuidance(GuidanceMixin):
|
||||
"""
|
||||
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
|
||||
|
||||
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
|
||||
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
|
||||
inference. This allows the model to tradeoff between generation quality and sample diversity.
|
||||
|
||||
The original paper proposes scaling and shifting the conditional distribution based on the difference between
|
||||
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
|
||||
|
||||
Diffusers implemented the scaling and shifting on the unconditional prediction instead, which is equivalent to what
|
||||
the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
|
||||
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
|
||||
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
|
||||
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
|
||||
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
skip_layer_guidance_scale: float = 2.8,
|
||||
skip_layer_guidance_start: float = 0.01,
|
||||
skip_layer_guidance_stop: float = 0.2,
|
||||
skip_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
self.skip_layer_guidance_start = skip_layer_guidance_start
|
||||
self.skip_layer_guidance_stop = skip_layer_guidance_stop
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if skip_guidance_layers is None and skip_layer_config is None:
|
||||
raise ValueError(
|
||||
"Either `skip_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
|
||||
)
|
||||
if skip_guidance_layers is not None and skip_layer_config is not None:
|
||||
raise ValueError("Only one of `skip_guidance_layers` or `skip_layer_config` can be provided.")
|
||||
|
||||
if skip_guidance_layers is not None:
|
||||
if isinstance(skip_guidance_layers, int):
|
||||
skip_guidance_layers = [skip_guidance_layers]
|
||||
if not isinstance(skip_guidance_layers, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_guidance_layers` to be an int or a list of ints, but got {type(skip_guidance_layers)}."
|
||||
)
|
||||
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_guidance_layers]
|
||||
|
||||
if isinstance(skip_layer_config, LayerSkipConfig):
|
||||
skip_layer_config = [skip_layer_config]
|
||||
|
||||
if not isinstance(skip_layer_config, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
|
||||
)
|
||||
|
||||
self.skip_layer_config = skip_layer_config
|
||||
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module):
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
|
||||
# Register the hooks for layer skipping if the step is within the specified range
|
||||
if skip_start_step < self._step < skip_stop_step:
|
||||
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||
|
||||
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
num_conditions = self.num_conditions
|
||||
list_of_inputs = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
list_of_inputs.append([arg] * num_conditions)
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
if len(arg) != 2:
|
||||
raise ValueError(
|
||||
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
|
||||
f"with the first element being the conditional input and the second element being the unconditional input or None."
|
||||
)
|
||||
if arg[1] is None:
|
||||
# Only conditioning inputs for all batches
|
||||
list_of_inputs.append([arg[0]] * num_conditions)
|
||||
else:
|
||||
list_of_inputs.append([arg[0], arg[1], arg[0]])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_cond_skip: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
|
||||
if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0):
|
||||
pred = pred_cond
|
||||
|
||||
elif math.isclose(self.guidance_scale, 1.0):
|
||||
if skip_start_step < self._step < skip_stop_step:
|
||||
shift = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
||||
pred = pred + self.skip_layer_guidance_scale * shift
|
||||
else:
|
||||
pred = pred_cond
|
||||
|
||||
elif math.isclose(self.skip_layer_guidance_scale, 1.0):
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if skip_start_step < self._step < skip_stop_step:
|
||||
shift_skip = pred_cond - pred_cond_skip
|
||||
pred = pred + self.skip_layer_guidance_scale * shift_skip
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
|
||||
if not math.isclose(self.guidance_scale, 1.0):
|
||||
num_conditions += 1
|
||||
if not math.isclose(self.skip_layer_guidance_scale, 1.0) and skip_start_step < self._step < skip_stop_step:
|
||||
num_conditions += 1
|
||||
|
||||
return num_conditions
|
||||
@@ -20,5 +20,6 @@ if is_torch_available():
|
||||
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
|
||||
110
src/diffusers/hooks/layer_skip.py
Normal file
110
src/diffusers/hooks/layer_skip.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from ._helpers import TransformerBlockRegistry
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_LAYER_SKIP_HOOK = "layer_skip_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerSkipConfig:
|
||||
r"""
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
fqn: str = "auto"
|
||||
|
||||
|
||||
class LayerSkipHook(ModelHook):
|
||||
def initialize_hook(self, module):
|
||||
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
return self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
|
||||
|
||||
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
|
||||
r"""
|
||||
Apply layer skipping to internal layers of a transformer.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The transformer model to which the layer skip hook should be applied.
|
||||
config (`LayerSkipConfig`):
|
||||
The configuration for the layer skip hook.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
|
||||
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
|
||||
>>> apply_layer_skip_hook(transformer, config)
|
||||
```
|
||||
"""
|
||||
_apply_layer_skip_hook(module, config)
|
||||
|
||||
|
||||
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
|
||||
name = name or _LAYER_SKIP_HOOK
|
||||
|
||||
if config.fqn == "auto":
|
||||
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
|
||||
if hasattr(module, identifier):
|
||||
config.fqn = identifier
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
|
||||
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
|
||||
)
|
||||
|
||||
transformer_blocks = getattr(module, config.fqn, None)
|
||||
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
|
||||
raise ValueError(
|
||||
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
|
||||
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
|
||||
)
|
||||
if len(config.indices) == 0:
|
||||
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
|
||||
|
||||
for i, block in enumerate(transformer_blocks):
|
||||
if i not in config.indices:
|
||||
continue
|
||||
logger.debug(f"Apply LayerSkipHook to '{config.fqn}.{i}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = LayerSkipHook()
|
||||
registry.register_hook(hook, name)
|
||||
@@ -520,7 +520,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
|
||||
_raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
|
||||
if guidance is None:
|
||||
guidance = ClassifierFreeGuidance(scale=guidance_scale)
|
||||
guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
@@ -612,29 +612,34 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs(
|
||||
latents,
|
||||
(prompt_embeds, negative_prompt_embeds),
|
||||
original_size,
|
||||
target_size,
|
||||
crops_coords_top_left,
|
||||
)
|
||||
|
||||
# Denoising loop
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
|
||||
prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
self._current_timestep = t
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
|
||||
guidance.prepare_models(self.transformer)
|
||||
latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs(
|
||||
latents,
|
||||
(prompt_embeds[0], negative_prompt_embeds[0]),
|
||||
original_size[0],
|
||||
target_size[0],
|
||||
crops_coords_top_left[0],
|
||||
)
|
||||
|
||||
noise_preds = []
|
||||
for i, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
|
||||
for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
|
||||
zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
|
||||
):
|
||||
cc.mark_state(f"batch_{i}")
|
||||
cc.mark_state(f"batch_{batch_index}")
|
||||
latent = latent.to(transformer_dtype)
|
||||
timestep = t.expand(latent.shape[0])
|
||||
noise_pred = self.transformer(
|
||||
@@ -651,6 +656,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
|
||||
noise_pred = guidance(*noise_preds)
|
||||
latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
|
||||
guidance.cleanup_models(self.transformer)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
@@ -664,10 +670,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
|
||||
]
|
||||
|
||||
latents, prompt_embeds = guidance.prepare_inputs(
|
||||
latents, (prompt_embeds[0], negative_prompt_embeds[0])
|
||||
)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
@@ -675,7 +677,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
latents = latents[0]
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
|
||||
@@ -17,6 +17,21 @@ class ClassifierFreeGuidance(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SkipLayerGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FasterCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -62,6 +77,21 @@ class HookRegistry(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LayerSkipConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -85,6 +115,10 @@ def apply_first_block_cache(*args, **kwargs):
|
||||
requires_backends(apply_first_block_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_layer_skip(*args, **kwargs):
|
||||
requires_backends(apply_layer_skip, ["torch"])
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(*args, **kwargs):
|
||||
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user