mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix slg bug
This commit is contained in:
@@ -86,7 +86,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled():
|
||||
if not self._is_apg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = normalized_guidance(
|
||||
@@ -111,11 +111,11 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
if self._is_apg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
def _is_apg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
|
||||
@@ -68,6 +68,13 @@ class BaseGuidance:
|
||||
self._preds = {}
|
||||
self._num_outputs_prepared = 0
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
|
||||
subclasses to implement specific model preparation logic.
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
|
||||
@@ -54,6 +54,10 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
skip_layer_guidance_scale (`float`, defaults to `2.8`):
|
||||
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
|
||||
values, but it may also lead to overexposure and saturation.
|
||||
skip_layer_guidance_start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance starts.
|
||||
skip_layer_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance stops.
|
||||
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
||||
@@ -81,20 +85,33 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
skip_layer_guidance_scale: float = 2.8,
|
||||
skip_layer_guidance_start: float = 0.01,
|
||||
skip_layer_guidance_stop: float = 0.2,
|
||||
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.01,
|
||||
stop: float = 0.2,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
self.skip_layer_guidance_start = skip_layer_guidance_start
|
||||
self.skip_layer_guidance_stop = skip_layer_guidance_stop
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if not (0.0 <= skip_layer_guidance_start < 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
|
||||
)
|
||||
if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
|
||||
)
|
||||
|
||||
if skip_layer_guidance_layers is None and skip_layer_config is None:
|
||||
raise ValueError(
|
||||
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
|
||||
@@ -122,11 +139,12 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
self.skip_layer_config = skip_layer_config
|
||||
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
||||
|
||||
def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
if self._num_outputs_prepared == 0 and self._is_slg_enabled():
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
if self._is_slg_enabled() and self.is_conditional and self._num_outputs_prepared > 0:
|
||||
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||
|
||||
|
||||
def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
num_conditions = self.num_conditions
|
||||
list_of_inputs = []
|
||||
for arg in args:
|
||||
@@ -161,7 +179,8 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
key = "pred_cond_skip"
|
||||
self._preds[key] = pred
|
||||
|
||||
if self._num_outputs_prepared == self.num_conditions:
|
||||
if key == "pred_cond_skip":
|
||||
# If we are in SLG mode, we need to remove the hooks after inference
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
@@ -233,8 +252,8 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||
|
||||
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
||||
|
||||
@@ -80,14 +80,17 @@ class AttentionProcessorSkipHook(ModelHook):
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.skip_attention_scores:
|
||||
print("Skipping attention scores")
|
||||
with AttentionScoreSkipFunctionMode():
|
||||
return self.fn_ref.original_forward(*args, **kwargs)
|
||||
else:
|
||||
print("Skipping attention processor output")
|
||||
return self.skip_processor_output_fn(module, *args, **kwargs)
|
||||
|
||||
|
||||
class FeedForwardSkipHook(ModelHook):
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
print("Skipping feed-forward block")
|
||||
output = kwargs.get("hidden_states", None)
|
||||
if output is None:
|
||||
output = kwargs.get("x", None)
|
||||
@@ -102,18 +105,22 @@ class TransformerBlockSkipHook(ModelHook):
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
print("Skipping transformer block")
|
||||
return self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
|
||||
|
||||
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
|
||||
r"""
|
||||
Apply layer skipping to internal layers of a transformer.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The transformer model to which the layer skip hook should be applied.
|
||||
config (`LayerSkipConfig`):
|
||||
The configuration for the layer skip hook.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
@@ -168,17 +175,13 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
|
||||
registry.register_hook(hook, name)
|
||||
elif config.skip_ff:
|
||||
if config.skip_ff:
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if isinstance(submodule, _FEEDFORWARD_CLASSES):
|
||||
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = FeedForwardSkipHook()
|
||||
registry.register_hook(hook, name)
|
||||
else:
|
||||
raise ValueError(
|
||||
"At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
|
||||
)
|
||||
|
||||
if not blocks_found:
|
||||
raise ValueError(
|
||||
|
||||
@@ -2267,6 +2267,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
|
||||
) in enumerate(zip(
|
||||
latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
|
||||
)):
|
||||
pipeline.guider.prepare_models(pipeline.unet)
|
||||
latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
|
||||
|
||||
# Prepare for inpainting
|
||||
@@ -2670,6 +2671,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
) in enumerate(zip(
|
||||
latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
|
||||
)):
|
||||
pipeline.guider.prepare_models(pipeline.unet)
|
||||
latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
|
||||
|
||||
# Prepare for inpainting
|
||||
@@ -3085,6 +3087,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
|
||||
) in enumerate(zip(
|
||||
latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
|
||||
)):
|
||||
pipeline.guider.prepare_models(pipeline.unet)
|
||||
latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
|
||||
|
||||
# Prepare for inpainting
|
||||
|
||||
Reference in New Issue
Block a user