1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Copied-from implementation of PAG-guider (#11882)

* update

* fix
This commit is contained in:
Aryan
2025-07-08 09:46:52 +05:30
committed by GitHub
parent e0083b29d5
commit be5e10ae61
2 changed files with 145 additions and 18 deletions

View File

@@ -12,18 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..hooks import LayerSkipConfig
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
from .skip_layer_guidance import SkipLayerGuidance
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
logger = get_logger(__name__) # pylint: disable=invalid-name
class PerturbedAttentionGuidance(SkipLayerGuidance):
class PerturbedAttentionGuidance(BaseGuidance):
"""
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
@@ -36,7 +44,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
PAG is implemented as a specialization of the SkipLayerGuidance due to similarities in the configuration parameters
PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
and implementation details.
Args:
@@ -75,6 +83,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
# for each model architecture.
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
@register_to_config
def __init__(
self,
@@ -89,6 +99,15 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale
self.skip_layer_guidance_start = perturbed_guidance_start
self.skip_layer_guidance_stop = perturbed_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if perturbed_guidance_config is None:
if perturbed_guidance_layers is None:
raise ValueError(
@@ -130,15 +149,123 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
config.skip_attention_scores = True
config.skip_ff = False
super().__init__(
guidance_scale=guidance_scale,
skip_layer_guidance_scale=perturbed_guidance_scale,
skip_layer_guidance_start=perturbed_guidance_start,
skip_layer_guidance_stop=perturbed_guidance_stop,
skip_layer_guidance_layers=perturbed_guidance_layers,
skip_layer_config=perturbed_guidance_config,
guidance_rescale=guidance_rescale,
use_original_formulation=use_original_formulation,
start=start,
stop=stop,
)
self.skip_layer_config = perturbed_guidance_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero

View File

@@ -335,7 +335,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
components_manager: Optional[ComponentsManager] = None,
collection: Optional[str] = None,
):
) -> "ModularPipeline":
"""
create a ModularPipeline, optionally accept modular_repo to load from hub.
"""