Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
@@ -12,18 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Union
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
 from ..configuration_utils import register_to_config
-from ..hooks import LayerSkipConfig
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
 from ..utils import get_logger
-from .skip_layer_guidance import SkipLayerGuidance
+from .guider_utils import BaseGuidance, rescale_noise_cfg
 
 
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
+
+
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class PerturbedAttentionGuidance(SkipLayerGuidance):
+class PerturbedAttentionGuidance(BaseGuidance):
     """
     Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
 
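
Background note, not part of the diff: the PAG paper's "perturbed" branch degrades selected self-attention layers by replacing the attention map with an identity matrix, so each query attends only to itself; in this file that behaviour is requested through skip_attention_scores=True on a LayerSkipConfig (see the hunks below). A minimal sketch of the idea, assuming standard scaled dot-product attention; this helper is illustrative only and is not how diffusers applies the hook:

# Illustrative only: with an identity attention map, the attention output collapses to the values.
import torch
import torch.nn.functional as F

def toy_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, perturbed: bool = False) -> torch.Tensor:
    if perturbed:
        return v  # softmax(QK^T / sqrt(d)) replaced by the identity matrix
    scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ v
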
@@ -36,7 +44,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
     Additional reading:
     - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
 
-    PAG is implemented as a specialization of the SkipLayerGuidance due to similarities in the configuration parameters
+    PAG is implemented similarly to SkipLayerGuidance due to overlap in the configuration parameters
     and implementation details.
 
     Args:
@@ -75,6 +83,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
     # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
     # for each model architecture.
 
+    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+
     @register_to_config
     def __init__(
         self,
@@ -89,6 +99,15 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
         start: float = 0.0,
         stop: float = 1.0,
     ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.skip_layer_guidance_scale = perturbed_guidance_scale
+        self.skip_layer_guidance_start = perturbed_guidance_start
+        self.skip_layer_guidance_stop = perturbed_guidance_stop
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
         if perturbed_guidance_config is None:
             if perturbed_guidance_layers is None:
                 raise ValueError(
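
For orientation, a hypothetical construction of the guider using the keyword arguments visible in this diff; the import path, the parameter values, and the layer index are assumptions for illustration, not taken from the change itself:

from diffusers.guiders import PerturbedAttentionGuidance  # import path assumed from the repo layout

guider = PerturbedAttentionGuidance(
    guidance_scale=7.5,             # classifier-free guidance strength
    perturbed_guidance_scale=3.0,   # weight of the perturbed-attention (skip-layer) term
    perturbed_guidance_layers=[8],  # transformer blocks whose attention scores are skipped
    perturbed_guidance_start=0.0,   # fraction of inference steps at which PAG switches on
    perturbed_guidance_stop=1.0,    # fraction of inference steps at which PAG switches off
)
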
@@ -130,15 +149,123 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
                 config.skip_attention_scores = True
                 config.skip_ff = False
 
-        super().__init__(
-            guidance_scale=guidance_scale,
-            skip_layer_guidance_scale=perturbed_guidance_scale,
-            skip_layer_guidance_start=perturbed_guidance_start,
-            skip_layer_guidance_stop=perturbed_guidance_stop,
-            skip_layer_guidance_layers=perturbed_guidance_layers,
-            skip_layer_config=perturbed_guidance_config,
-            guidance_rescale=guidance_rescale,
-            use_original_formulation=use_original_formulation,
-            start=start,
-            stop=stop,
-        )
+        self.skip_layer_config = perturbed_guidance_config
+        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
+    def prepare_models(self, denoiser: torch.nn.Module) -> None:
+        self._count_prepared += 1
+        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
+                _apply_layer_skip_hook(denoiser, config, name=name)
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
+    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+            # Remove the hooks after inference
+            for hook_name in self._skip_layer_hook_names:
+                registry.remove_hook(hook_name, recurse=True)
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
+    def prepare_inputs(
+        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+    ) -> List["BlockState"]:
+        if input_fields is None:
+            input_fields = self._input_fields
+
+        if self.num_conditions == 1:
+            tuple_indices = [0]
+            input_predictions = ["pred_cond"]
+        elif self.num_conditions == 2:
+            tuple_indices = [0, 1]
+            input_predictions = (
+                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+            )
+        else:
+            tuple_indices = [0, 1, 0]
+            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
+    def forward(
+        self,
+        pred_cond: torch.Tensor,
+        pred_uncond: Optional[torch.Tensor] = None,
+        pred_cond_skip: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        pred = None
+
+        if not self._is_cfg_enabled() and not self._is_slg_enabled():
+            pred = pred_cond
+        elif not self._is_cfg_enabled():
+            shift = pred_cond - pred_cond_skip
+            pred = pred_cond if self.use_original_formulation else pred_cond_skip
+            pred = pred + self.skip_layer_guidance_scale * shift
+        elif not self._is_slg_enabled():
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+        else:
+            shift = pred_cond - pred_uncond
+            shift_skip = pred_cond - pred_cond_skip
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred, {}
+
+    @property
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
+    def is_conditional(self) -> bool:
+        return self._count_prepared == 1 or self._count_prepared == 3
+
+    @property
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfg_enabled():
+            num_conditions += 1
+        if self._is_slg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
+    def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
+    def _is_slg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
+            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
+            is_within_range = skip_start_step < self._step < skip_stop_step
+
+        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
+        return is_within_range and not is_zero
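
A standalone sketch, not part of the diff, of the combination that forward computes when both classifier-free guidance and the perturbed-attention branch are active and use_original_formulation is False; the function name and default scales are illustrative:

import torch

def combine_predictions(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    pred_cond_skip: torch.Tensor,
    guidance_scale: float = 7.5,
    perturbed_guidance_scale: float = 3.0,
) -> torch.Tensor:
    shift = pred_cond - pred_uncond          # classifier-free guidance direction
    shift_skip = pred_cond - pred_cond_skip  # perturbed-attention guidance direction
    return pred_uncond + guidance_scale * shift + perturbed_guidance_scale * shift_skip

# Step gating, as implemented in _is_slg_enabled above: with 50 inference steps,
# perturbed_guidance_start=0.0 and perturbed_guidance_stop=0.5 give skip_start_step=0 and
# skip_stop_step=25, so the perturbed branch is applied only while 0 < step < 25.
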
@@ -335,7 +335,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
         components_manager: Optional[ComponentsManager] = None,
         collection: Optional[str] = None,
-    ):
+    ) -> "ModularPipeline":
         """
         create a ModularPipeline, optionally accept modular_repo to load from hub.
         """