mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fix slg bug

commit 8d31c699a5
Author: Aryan
Date:   2025-04-15 08:26:03 +02:00
Parent: 57c7e15a91

5 changed files with 48 additions and 16 deletions

View File

@@ -86,7 +86,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
     def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
         pred = None

-        if not self._is_cfg_enabled():
+        if not self._is_apg_enabled():
             pred = pred_cond
         else:
             pred = normalized_guidance(
@@ -111,11 +111,11 @@ class AdaptiveProjectedGuidance(BaseGuidance):
     @property
     def num_conditions(self) -> int:
         num_conditions = 1
-        if self._is_cfg_enabled():
+        if self._is_apg_enabled():
             num_conditions += 1
         return num_conditions

-    def _is_cfg_enabled(self) -> bool:
+    def _is_apg_enabled(self) -> bool:
         if not self._enabled:
             return False

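The rename above is the substantive fix in this file: `AdaptiveProjectedGuidance` still gated itself on `_is_cfg_enabled`, a leftover from the CFG guider it was adapted from. The check matters beyond naming because `num_conditions` tells the pipeline how many denoiser passes to prepare. Below is a minimal runnable sketch of that pattern; `MyGuidance` is hypothetical, only the `_is_apg_enabled`/`num_conditions`/`forward` names mirror the diff, and a plain CFG combination stands in for APG's `normalized_guidance`.

```python
from typing import Optional

import torch


class MyGuidance:
    # Hypothetical guider; illustrates the enable-check pattern only.
    def __init__(self, guidance_scale: float = 7.5, enabled: bool = True):
        self.guidance_scale = guidance_scale
        self._enabled = enabled

    def _is_apg_enabled(self) -> bool:
        # Guidance degenerates to the conditional prediction when disabled
        # or when the scale is effectively 1.0.
        return self._enabled and self.guidance_scale != 1.0

    @property
    def num_conditions(self) -> int:
        # One denoiser pass for the conditional branch, plus an unconditional
        # pass only when guidance is actually active.
        return 2 if self._is_apg_enabled() else 1

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        if not self._is_apg_enabled():
            return pred_cond
        # Plain CFG combination as a stand-in for APG's normalized_guidance.
        return pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)
```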
View File

@@ -68,6 +68,13 @@ class BaseGuidance:
         self._preds = {}
         self._num_outputs_prepared = 0

+    def prepare_models(self, denoiser: torch.nn.Module) -> None:
+        """
+        Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
+        subclasses to implement specific model preparation logic.
+        """
+        pass
+
     def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
         raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

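This default `prepare_models` is the seam the SLG fix hangs on: guiders that need to mutate the denoiser (for example, by attaching hooks) override it, and pipelines call it before each denoiser pass. A runnable sketch of the two-method contract, using a trivial stand-in guider rather than any real diffusers class:

```python
from typing import List, Tuple

import torch


class NoOpGuidance:
    # Stand-in implementing the two-method contract from the diff above.
    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        # Default is a no-op; SkipLayerGuidance overrides this to attach
        # layer-skip hooks before the right denoiser pass.
        pass

    def prepare_inputs(self, denoiser: torch.nn.Module, *args: torch.Tensor) -> Tuple[List[torch.Tensor], ...]:
        # One copy of each input per condition; a single condition here.
        return tuple([arg] for arg in args)


guider = NoOpGuidance()
denoiser = torch.nn.Linear(4, 4)  # stand-in for a UNet/transformer
latents = torch.randn(1, 4)

guider.prepare_models(denoiser)  # model-side preparation happens first
(latents_list,) = guider.prepare_inputs(denoiser, latents)
preds = [denoiser(x) for x in latents_list]  # one pass per condition
```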
View File

@@ -54,6 +54,10 @@ class SkipLayerGuidance(BaseGuidance):
skip_layer_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
skip_layer_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
@@ -81,20 +85,33 @@ class SkipLayerGuidance(BaseGuidance):
         self,
         guidance_scale: float = 7.5,
         skip_layer_guidance_scale: float = 2.8,
+        skip_layer_guidance_start: float = 0.01,
+        skip_layer_guidance_stop: float = 0.2,
         skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
         skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
         guidance_rescale: float = 0.0,
         use_original_formulation: bool = False,
-        start: float = 0.01,
-        stop: float = 0.2,
+        start: float = 0.0,
+        stop: float = 1.0,
     ):
         super().__init__(start, stop)

         self.guidance_scale = guidance_scale
         self.skip_layer_guidance_scale = skip_layer_guidance_scale
+        self.skip_layer_guidance_start = skip_layer_guidance_start
+        self.skip_layer_guidance_stop = skip_layer_guidance_stop
         self.guidance_rescale = guidance_rescale
         self.use_original_formulation = use_original_formulation

+        if not (0.0 <= skip_layer_guidance_start < 1.0):
+            raise ValueError(
+                f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
+            )
+        if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
+            raise ValueError(
+                f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
+            )
+
         if skip_layer_guidance_layers is None and skip_layer_config is None:
             raise ValueError(
                 "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
@@ -122,11 +139,12 @@ class SkipLayerGuidance(BaseGuidance):
         self.skip_layer_config = skip_layer_config
         self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

-    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
-        if self._num_outputs_prepared == 0 and self._is_slg_enabled():
+    def prepare_models(self, denoiser: torch.nn.Module) -> None:
+        if self._is_slg_enabled() and self.is_conditional and self._num_outputs_prepared > 0:
             for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                 _apply_layer_skip_hook(denoiser, config, name=name)

+    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
         num_conditions = self.num_conditions
         list_of_inputs = []
         for arg in args:
@@ -161,7 +179,8 @@ class SkipLayerGuidance(BaseGuidance):
             key = "pred_cond_skip"
         self._preds[key] = pred

-        if self._num_outputs_prepared == self.num_conditions:
+        if key == "pred_cond_skip":
+            # If we are in SLG mode, we need to remove the hooks after inference
             registry = HookRegistry.check_if_exists_or_initialize(denoiser)
-            # Remove the hooks after inference
             for hook_name in self._skip_layer_hook_names:
@@ -233,8 +252,8 @@ class SkipLayerGuidance(BaseGuidance):
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

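Taken together, these hunks are the core of the fix. Hook attachment moves out of `prepare_inputs`, where the old `_num_outputs_prepared == 0` condition attached the layer-skip hooks before the first prediction so that every pass of a step ran with layers skipped, and into `prepare_models`, gated so the hooks exist only for the conditional skip pass and are removed once `pred_cond_skip` is stored. The last hunk also stops borrowing the base class's `_start`/`_stop` window, since SLG's window is now configured separately from the guider's overall enable window. A worked sketch of how the fractional window maps to concrete steps, using the new defaults; the arithmetic mirrors the diff:

```python
# How skip_layer_guidance_start/stop translate into an active step range.
num_inference_steps = 50
skip_layer_guidance_start = 0.01  # new constructor default
skip_layer_guidance_stop = 0.2

skip_start_step = int(skip_layer_guidance_start * num_inference_steps)  # 0
skip_stop_step = int(skip_layer_guidance_stop * num_inference_steps)  # 10

# SLG is active only strictly inside the window, i.e. steps 1..9 of 0..49.
active = [step for step in range(num_inference_steps) if skip_start_step < step < skip_stop_step]
assert active == list(range(1, 10))
```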
View File

@@ -80,14 +80,17 @@ class AttentionProcessorSkipHook(ModelHook):
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
         if self.skip_attention_scores:
+            print("Skipping attention scores")
             with AttentionScoreSkipFunctionMode():
                 return self.fn_ref.original_forward(*args, **kwargs)
         else:
+            print("Skipping attention processor output")
             return self.skip_processor_output_fn(module, *args, **kwargs)


 class FeedForwardSkipHook(ModelHook):
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        print("Skipping feed-forward block")
         output = kwargs.get("hidden_states", None)
         if output is None:
             output = kwargs.get("x", None)
@@ -102,18 +105,22 @@ class TransformerBlockSkipHook(ModelHook):
         return module

     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        print("Skipping transformer block")
         return self._metadata.skip_block_output_fn(module, *args, **kwargs)


 def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
     r"""
     Apply layer skipping to internal layers of a transformer.

     Args:
         module (`torch.nn.Module`):
             The transformer model to which the layer skip hook should be applied.
         config (`LayerSkipConfig`):
             The configuration for the layer skip hook.

     Example:

     ```python
     >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
     >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
@@ -168,17 +175,13 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
                     registry = HookRegistry.check_if_exists_or_initialize(submodule)
                     hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
                     registry.register_hook(hook, name)
-        elif config.skip_ff:
+        if config.skip_ff:
             for submodule_name, submodule in block.named_modules():
                 if isinstance(submodule, _FEEDFORWARD_CLASSES):
                     logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                     registry = HookRegistry.check_if_exists_or_initialize(submodule)
                     hook = FeedForwardSkipHook()
                     registry.register_hook(hook, name)
-        else:
-            raise ValueError(
-                "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
-            )

     if not blocks_found:
         raise ValueError(

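Besides the debug `print` calls, the behavioral change here is `elif config.skip_ff` becoming `if config.skip_ff`: previously a config that set both `skip_attention_scores` and `skip_ff` silently ignored the feed-forward flag because the first matching branch won, and the trailing `else` that raised when no flag was set is dropped along with the chain. A small runnable sketch of the independent-flags dispatch; the dataclass below is a simplified stand-in for the real `LayerSkipConfig`:

```python
from dataclasses import dataclass


@dataclass
class LayerSkipConfig:  # simplified stand-in for the diffusers config
    skip_attention_scores: bool = False
    skip_ff: bool = False


def hooks_to_apply(config: LayerSkipConfig) -> list:
    hooks = []
    if config.skip_attention_scores:  # independent `if`, not `elif`
        hooks.append("AttentionProcessorSkipHook")
    if config.skip_ff:  # both hooks can now fire for the same block
        hooks.append("FeedForwardSkipHook")
    return hooks


both = LayerSkipConfig(skip_attention_scores=True, skip_ff=True)
assert hooks_to_apply(both) == ["AttentionProcessorSkipHook", "FeedForwardSkipHook"]
```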
View File

@@ -2267,6 +2267,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
         ) in enumerate(zip(
             latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
         )):
+            pipeline.guider.prepare_models(pipeline.unet)
             latents_i = pipeline.scheduler.scale_model_input(latents_i, t)

             # Prepare for inpainting
@@ -2670,6 +2671,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         ) in enumerate(zip(
             latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
         )):
+            pipeline.guider.prepare_models(pipeline.unet)
             latents_i = pipeline.scheduler.scale_model_input(latents_i, t)

             # Prepare for inpainting
@@ -3085,6 +3087,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
         ) in enumerate(zip(
             latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
         )):
+            pipeline.guider.prepare_models(pipeline.unet)
             latents_i = pipeline.scheduler.scale_model_input(latents_i, t)

             # Prepare for inpainting
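All three SDXL denoise blocks gain the same one-line call, so the guider's `prepare_models` now runs immediately before every denoiser invocation rather than not at all, which is what lets `SkipLayerGuidance` attach its hooks for exactly the pass that needs them. A runnable sketch of that ordering; every object below is a stand-in, not a diffusers API:

```python
import torch


class CountingGuider:
    # Stand-in; counts prepare_models calls to show the per-pass ordering.
    def __init__(self):
        self.calls = 0

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        self.calls += 1  # SkipLayerGuidance would (re)attach hooks here


guider = CountingGuider()
unet = torch.nn.Linear(4, 4)  # stand-in for the SDXL UNet
latents_batches = [torch.randn(1, 4) for _ in range(3)]  # one per condition

for latents_i in latents_batches:
    guider.prepare_models(unet)  # always before the model input is prepared
    noise_pred = unet(latents_i)

assert guider.calls == len(latents_batches)
```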