diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index ab5175745b..45bd196860 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -86,7 +86,7 @@ class AdaptiveProjectedGuidance(BaseGuidance): def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None - if not self._is_cfg_enabled(): + if not self._is_apg_enabled(): pred = pred_cond else: pred = normalized_guidance( @@ -111,11 +111,11 @@ class AdaptiveProjectedGuidance(BaseGuidance): @property def num_conditions(self) -> int: num_conditions = 1 - if self._is_cfg_enabled(): + if self._is_apg_enabled(): num_conditions += 1 return num_conditions - def _is_cfg_enabled(self) -> bool: + def _is_apg_enabled(self) -> bool: if not self._enabled: return False diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 690afae891..60859bf390 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -68,6 +68,13 @@ class BaseGuidance: self._preds = {} self._num_outputs_prepared = 0 + def prepare_models(self, denoiser: torch.nn.Module) -> None: + """ + Prepares the models for the guidance technique on a given batch of data. This method should be overridden in + subclasses to implement specific model preparation logic. + """ + pass + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index bac851c0dc..3fbfd771ef 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -54,6 +54,10 @@ class SkipLayerGuidance(BaseGuidance): skip_layer_guidance_scale (`float`, defaults to `2.8`): The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher values, but it may also lead to overexposure and saturation. + skip_layer_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which skip layer guidance starts. + skip_layer_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which skip layer guidance stops. skip_layer_guidance_layers (`int` or `List[int]`, *optional*): The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion @@ -81,20 +85,33 @@ class SkipLayerGuidance(BaseGuidance): self, guidance_scale: float = 7.5, skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, guidance_rescale: float = 0.0, use_original_formulation: bool = False, - start: float = 0.01, - stop: float = 0.2, + start: float = 0.0, + stop: float = 1.0, ): super().__init__(start, stop) self.guidance_scale = guidance_scale self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation + if not (0.0 <= skip_layer_guidance_start < 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}." + ) + if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}." + ) + if skip_layer_guidance_layers is None and skip_layer_config is None: raise ValueError( "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." @@ -122,11 +139,12 @@ class SkipLayerGuidance(BaseGuidance): self.skip_layer_config = skip_layer_config self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - if self._num_outputs_prepared == 0 and self._is_slg_enabled(): + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._num_outputs_prepared > 0: for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): _apply_layer_skip_hook(denoiser, config, name=name) - + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: num_conditions = self.num_conditions list_of_inputs = [] for arg in args: @@ -161,7 +179,8 @@ class SkipLayerGuidance(BaseGuidance): key = "pred_cond_skip" self._preds[key] = pred - if self._num_outputs_prepared == self.num_conditions: + if key == "pred_cond_skip": + # If we are in SLG mode, we need to remove the hooks after inference registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference for hook_name in self._skip_layer_hook_names: @@ -233,8 +252,8 @@ class SkipLayerGuidance(BaseGuidance): is_within_range = True if self._num_inference_steps is not None: - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) is_within_range = skip_start_step < self._step < skip_stop_step is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index b847152390..e28322ac48 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -80,14 +80,17 @@ class AttentionProcessorSkipHook(ModelHook): def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: + print("Skipping attention scores") with AttentionScoreSkipFunctionMode(): return self.fn_ref.original_forward(*args, **kwargs) else: + print("Skipping attention processor output") return self.skip_processor_output_fn(module, *args, **kwargs) class FeedForwardSkipHook(ModelHook): def new_forward(self, module: torch.nn.Module, *args, **kwargs): + print("Skipping feed-forward block") output = kwargs.get("hidden_states", None) if output is None: output = kwargs.get("x", None) @@ -102,18 +105,22 @@ class TransformerBlockSkipHook(ModelHook): return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): + print("Skipping transformer block") return self._metadata.skip_block_output_fn(module, *args, **kwargs) def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: r""" Apply layer skipping to internal layers of a transformer. + Args: module (`torch.nn.Module`): The transformer model to which the layer skip hook should be applied. config (`LayerSkipConfig`): The configuration for the layer skip hook. + Example: + ```python >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) @@ -168,17 +175,13 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam registry = HookRegistry.check_if_exists_or_initialize(submodule) hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores) registry.register_hook(hook, name) - elif config.skip_ff: + if config.skip_ff: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _FEEDFORWARD_CLASSES): logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") registry = HookRegistry.check_if_exists_or_initialize(submodule) hook = FeedForwardSkipHook() registry.register_hook(hook, name) - else: - raise ValueError( - "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True." - ) if not blocks_found: raise ValueError( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 68d9d913bd..aed212c3f8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2267,6 +2267,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): ) in enumerate(zip( latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): + pipeline.guider.prepare_models(pipeline.unet) latents_i = pipeline.scheduler.scale_model_input(latents_i, t) # Prepare for inpainting @@ -2670,6 +2671,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): ) in enumerate(zip( latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): + pipeline.guider.prepare_models(pipeline.unet) latents_i = pipeline.scheduler.scale_model_input(latents_i, t) # Prepare for inpainting @@ -3085,6 +3087,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): ) in enumerate(zip( latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): + pipeline.guider.prepare_models(pipeline.unet) latents_i = pipeline.scheduler.scale_model_input(latents_i, t) # Prepare for inpainting