
PAG variant for HunyuanDiT, PAG refactor (#8936)

* copy hunyuandit pipeline

* pag variant of hunyuan dit

* add tests

* update docs

* make style

* make fix-copies

* Update src/diffusers/pipelines/pag/pag_utils.py

* remove incorrect copied from

* remove pag hunyuan attn procs to resolve conflicts

* add pag attn procs again

* new implementation for pag_utils

* revert pag changes

* add pag refactor back; update pixart sigma

* update pixart pag tests

* apply suggestions from review

Co-Authored-By: yixu310@gmail.com

* make style

* update docs, fix tests

* fix tests

* fix test_components_function since list not accepted as valid __init__ param

* apply patch to fix broken tests

Co-Authored-By: Sayak Paul <spsayakpaul@gmail.com>

* make style

* fix hunyuan tests

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Aryan authored on 2024-08-05 17:56:09 +05:30; committed by GitHub
parent e1d508ae92 · commit b7058d142c
16 changed files with 1748 additions and 365 deletions


@@ -20,11 +20,29 @@ The abstract from the paper is:
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
PAG can be used by specifying `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression matching one or more layers.
- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`
- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`
- Partial identifier as a RegEx: `down_blocks.2`, or `attn1`
- List of identifiers (can be a combination of strings and RegEx): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2|3)"]`
<Tip warning={true}>
Since RegEx is supported for matching layer identifiers, it is crucial to use it correctly; otherwise, there might be unexpected behaviour. The recommended way to use PAG is by specifying layers as `blocks.{layer_index}` and `blocks.({layer_index_1}|{layer_index_2}|...)`. Using it in any other way, while possible, may bypass our basic validation checks and give you unexpected results.
</Tip>
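For example, a PAG pipeline can be created by passing `enable_pag=True` to an auto pipeline (a minimal sketch; the checkpoint and layer choice below are illustrative):

```python
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    enable_pag=True,
    pag_applied_layers=["mid"],  # partial identifier matching the mid-block self-attention layers
    torch_dtype=torch.float16,
).to("cuda")
image = pipeline("an astronaut riding a horse", pag_scale=3.0).images[0]
```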
## AnimateDiffPAGPipeline
[[autodoc]] AnimateDiffPAGPipeline
- all
- __call__
## HunyuanDiTPAGPipeline
[[autodoc]] HunyuanDiTPAGPipeline
- all
- __call__
## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
@@ -59,4 +77,4 @@ The abstract from the paper is:
## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
- all
- __call__
- __call__


@@ -252,6 +252,7 @@ else:
"CycleDiffusionPipeline",
"FluxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
@@ -675,6 +676,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CycleDiffusionPipeline,
FluxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,


@@ -2147,6 +2147,253 @@ class FusedHunyuanAttnProcessor2_0:
return hidden_states
class PAGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
# 1. Original Path
batch_size, sequence_length, _ = (
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states_org)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states_org
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 2. Perturbed Path
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
hidden_states_ptb = attn.to_v(hidden_states_ptb)
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
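Note that the perturbed path above never calls `scaled_dot_product_attention`: substituting the self-attention map with an identity matrix makes the attention output equal to the value tensor, so only `attn.to_v` and the output projection are applied. A standalone sketch of that equivalence (hypothetical shapes):

```python
import torch

seq_len, dim = 4, 8
value = torch.randn(seq_len, dim)
# An identity attention map leaves the value tensor unchanged, which is why
# the perturbed branch reduces to a plain value projection.
assert torch.allclose(torch.eye(seq_len) @ value, value)
```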
class PAGCFGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377) together with classifier-free guidance.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
# 1. Original Path
batch_size, sequence_length, _ = (
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states_org)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states_org
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 2. Perturbed Path
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
hidden_states_ptb = attn.to_v(hidden_states_ptb)
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -3468,4 +3715,6 @@ AttentionProcessor = Union[
CustomDiffusionAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
]


@@ -145,6 +145,7 @@ else:
_import_structure["pag"].extend(
[
"AnimateDiffPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
@@ -532,6 +533,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .musicldm import MusicLDMPipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,


@@ -50,6 +50,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
@@ -85,6 +86,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-3", StableDiffusion3Pipeline),
("if", IFPipeline),
("hunyuan", HunyuanDiTPipeline),
("hunyuan-pag", HunyuanDiTPAGPipeline),
("kandinsky", KandinskyCombinedPipeline),
("kandinsky22", KandinskyV22CombinedPipeline),
("kandinsky3", Kandinsky3Pipeline),


@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
@@ -41,6 +42,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline


@@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
from ...models.attention_processor import (
Attention,
AttentionProcessor,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
)
@@ -25,332 +31,60 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class PAGMixin:
r"""Mixin class for PAG."""
@staticmethod
def _check_input_pag_applied_layer(layer):
r"""
Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats:
"{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type`
can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be
in the format of "attentions_{j}". `motion_modules_index` should be in the format of "motion_modules_{j}"
"""
layer_splits = layer.split(".")
if len(layer_splits) > 3:
raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.")
if len(layer_splits) >= 1:
if layer_splits[0] not in ["down", "mid", "up"]:
raise ValueError(
f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}"
)
if len(layer_splits) >= 2:
if not layer_splits[1].startswith("block_"):
raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'")
if len(layer_splits) == 3:
layer_2 = layer_splits[2]
if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"):
raise ValueError(
f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'"
)
r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1)."""
def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
r"""
Set the attention processor for the PAG layers.
"""
if do_classifier_free_guidance:
pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
else:
pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()
pag_attn_processors = self._pag_attn_processors
if pag_attn_processors is None:
raise ValueError(
"No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters."
)
def is_self_attn(module_name):
pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]
if hasattr(self, "unet"):
model: nn.Module = self.unet
else:
model: nn.Module = self.transformer
def is_self_attn(module: nn.Module) -> bool:
r"""
Check if the module is self-attention module based on its name.
"""
return "attn1" in module_name and "to" not in name
return isinstance(module, Attention) and not module.is_cross_attention
def get_block_type(module_name):
r"""
Get the block type from the module name. Can be "down", "mid", "up".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down"
# down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down"
return module_name.split(".")[0].split("_")[0]
def is_fake_integral_match(layer_id, name):
layer_id = layer_id.split(".")[-1]
name = name.split(".")[-1]
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
def get_block_index(module_name):
r"""
Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g.
mid_block) and index is omitted from the name, it will be "block_0".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0"
module_name_splits = module_name.split(".")
block_index = module_name_splits[1]
if "attentions" in block_index or "motion_modules" in block_index:
return "block_0"
else:
return f"block_{block_index}"
def get_attn_index(module_name):
r"""
Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0",
"motion_modules_1", ...
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
# down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
# mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
module_name_split = module_name.split(".")
mid_name = module_name_split[1]
down_name = module_name_split[2]
if "attentions" in down_name:
return f"attentions_{module_name_split[3]}"
if "attentions" in mid_name:
return f"attentions_{module_name_split[2]}"
if "motion_modules" in down_name:
return f"motion_modules_{module_name_split[3]}"
if "motion_modules" in mid_name:
return f"motion_modules_{module_name_split[2]}"
for pag_layer_input in pag_applied_layers:
for layer_id in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the unet model
target_modules = []
pag_layer_input_splits = pag_layer_input.split(".")
if len(pag_layer_input_splits) == 1:
# when the layer input only contains block_type. e.g. "mid", "down", "up"
block_type = pag_layer_input_splits[0]
for name, module in self.unet.named_modules():
if is_self_attn(name) and get_block_type(name) == block_type:
target_modules.append(module)
elif len(pag_layer_input_splits) == 2:
# when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
for name, module in self.unet.named_modules():
if (
is_self_attn(name)
and get_block_type(name) == block_type
and get_block_index(name) == block_index
):
target_modules.append(module)
elif len(pag_layer_input_splits) == 3:
# when the layer input contains block_type, block_index and attention_index.
# e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
attn_index = pag_layer_input_splits[2]
for name, module in self.unet.named_modules():
if (
is_self_attn(name)
and get_block_type(name) == block_type
and get_block_index(name) == block_index
and get_attn_index(name) == attn_index
):
target_modules.append(module)
if len(target_modules) == 0:
raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}")
for module in target_modules:
module.processor = pag_attn_proc
def _get_pag_scale(self, t):
r"""
Get the scale factor for the perturbed attention guidance at timestep `t`.
"""
if self.do_pag_adaptive_scaling:
signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
if signal_scale < 0:
signal_scale = 0
return signal_scale
else:
return self.pag_scale
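With adaptive scaling enabled, the guidance strength decays linearly as the timestep counts down from 1000 and is clamped at zero. A small numeric illustration (hypothetical values):

```python
# pag_scale=3.0, pag_adaptive_scale=0.002; timesteps count down from 1000.
for t in (1000, 900, 500, 0):
    signal_scale = max(3.0 - 0.002 * (1000 - t), 0)
    print(t, signal_scale)  # 1000 -> 3.0, 900 -> 2.8, 500 -> 2.0, 0 -> 1.0
```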
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
r"""
Apply perturbed attention guidance to the noise prediction.
Args:
noise_pred (torch.Tensor): The noise prediction tensor.
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
guidance_scale (float): The scale factor for the guidance term.
t (int): The current time step.
Returns:
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
"""
pag_scale = self._get_pag_scale(t)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = (
noise_pred_uncond
+ guidance_scale * (noise_pred_text - noise_pred_uncond)
+ pag_scale * (noise_pred_text - noise_pred_perturb)
)
else:
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
return noise_pred
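The update above extends the usual CFG formula with a second difference term that steers the denoising away from the structurally degraded (perturbed) prediction. A standalone sketch of the CFG branch (hypothetical tensors):

```python
import torch

# Batch laid out as [uncond, text, perturbed], as produced when CFG is enabled.
noise_pred = torch.randn(3, 4, 64, 64)
guidance_scale, pag_scale = 5.0, 3.0
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_uncond)
    + pag_scale * (noise_pred_text - noise_pred_perturb)
)
```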
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
"""
Prepares the perturbed attention guidance for the PAG model.
Args:
cond (torch.Tensor): The conditional input tensor.
uncond (torch.Tensor): The unconditional input tensor.
do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.
Returns:
torch.Tensor: The prepared perturbed attention guidance tensor.
"""
cond = torch.cat([cond] * 2, dim=0)
if do_classifier_free_guidance:
cond = torch.cat([uncond, cond], dim=0)
return cond
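The helper duplicates the conditional inputs so a single forward pass serves both the original and the perturbed branch, prepending the unconditional inputs when CFG is on. A quick shape check (hypothetical embedding sizes):

```python
import torch

cond = torch.randn(1, 77, 768)    # conditional prompt embeddings
uncond = torch.randn(1, 77, 768)  # unconditional prompt embeddings
batch = torch.cat([cond] * 2, dim=0)       # [cond, cond] for the PAG branches
batch = torch.cat([uncond, batch], dim=0)  # [uncond, cond, cond] with CFG
assert batch.shape == (3, 77, 768)
```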
def set_pag_applied_layers(self, pag_applied_layers):
r"""
Set the self-attention layers to apply PAG to. Raises a ValueError if the input is invalid.
"""
if not isinstance(pag_applied_layers, list):
pag_applied_layers = [pag_applied_layers]
for pag_layer in pag_applied_layers:
self._check_input_pag_applied_layer(pag_layer)
self.pag_applied_layers = pag_applied_layers
@property
def pag_scale(self):
"""
Get the scale factor for the perturbed attention guidance.
"""
return self._pag_scale
@property
def pag_adaptive_scale(self):
"""
Get the adaptive scale factor for the perturbed attention guidance.
"""
return self._pag_adaptive_scale
@property
def do_pag_adaptive_scaling(self):
"""
Check if the adaptive scaling is enabled for the perturbed attention guidance.
"""
return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
def do_perturbed_attention_guidance(self):
"""
Check if the perturbed attention guidance is enabled.
"""
return self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
def pag_attn_processors(self):
r"""
Returns:
`dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the model,
with the layer names as keys.
"""
processors = {}
for name, proc in self.unet.attn_processors.items():
if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
processors[name] = proc
return processors
class PixArtPAGMixin:
@staticmethod
def _check_input_pag_applied_layer(layer):
r"""
Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}.
"""
# Check if the layer index is valid (should be int or str of int)
if isinstance(layer, int):
return # Valid layer index
if isinstance(layer, str):
if layer.isdigit():
return # Valid layer index
# If it is not a valid layer index, raise a ValueError
raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}.")
def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
r"""
Set the attention processor for the PAG layers.
"""
if do_classifier_free_guidance:
pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
else:
pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()
def is_self_attn(module_name):
r"""
Check if the module is self-attention module based on its name.
"""
return (
"attn1" in module_name and len(module_name.split(".")) == 3
) # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ...
def get_block_index(module_name):
r"""
Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g.
mid_block) and index is omitted from the name, it will be "block_0".
"""
# transformer_blocks.23.attn -> "23"
return module_name.split(".")[1]
for pag_layer_input in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the transformer model
target_modules = []
block_index = str(pag_layer_input)
for name, module in self.transformer.named_modules():
if is_self_attn(name) and get_block_index(name) == block_index:
for name, module in model.named_modules():
# Identify the following simple cases:
# (1) Self Attention layer existing
# (2) Whether the module name matches pag layer id even partially
# (3) Make sure it's not a fake integral match if the layer_id ends with a number
# For example, blocks.1 and blocks.10 should be distinguishable if layer_id="blocks.1"
if (
is_self_attn(module)
and re.search(layer_id, name) is not None
and not is_fake_integral_match(layer_id, name)
):
logger.debug(f"Applying PAG to layer: {name}")
target_modules.append(module)
if len(target_modules) == 0:
raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}")
raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")
for module in target_modules:
module.processor = pag_attn_proc
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers
def set_pag_applied_layers(self, pag_applied_layers):
r"""
Set the self-attention layers to apply PAG to. Raises a ValueError if the input is invalid.
"""
if not isinstance(pag_applied_layers, list):
pag_applied_layers = [pag_applied_layers]
for pag_layer in pag_applied_layers:
self._check_input_pag_applied_layer(pag_layer)
self.pag_applied_layers = pag_applied_layers
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale
def _get_pag_scale(self, t):
r"""
Get the scale factor for the perturbed attention guidance at timestep `t`.
@@ -364,7 +98,6 @@ class PixArtPAGMixin:
else:
return self.pag_scale
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
r"""
Apply perturbed attention guidance to the noise prediction.
@@ -391,7 +124,6 @@ class PixArtPAGMixin:
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
return noise_pred
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
"""
Prepares the perturbed attention guidance for the PAG model.
@@ -411,49 +143,95 @@ class PixArtPAGMixin:
cond = torch.cat([uncond, cond], dim=0)
return cond
def set_pag_applied_layers(
self,
pag_applied_layers: Union[str, List[str]],
pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
PAGCFGIdentitySelfAttnProcessor2_0(),
PAGIdentitySelfAttnProcessor2_0(),
),
):
r"""
Set the self-attention layers to apply PAG to. Raises a ValueError if the input is invalid.
Args:
pag_applied_layers (`str` or `List[str]`):
One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
PAG is to be applied. A few ways of expected usage are as follows:
- Single layers specified as - "blocks.{layer_index}"
- Multiple layers as a list - ["blocks.{layer_index_1}", "blocks.{layer_index_2}", ...]
- Multiple layers as a block name - "mid"
- Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
pag_attn_processors:
(`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second
attention processor is for PAG with CFG disabled (unconditional only).
"""
if not hasattr(self, "_pag_attn_processors"):
self._pag_attn_processors = None
if not isinstance(pag_applied_layers, list):
pag_applied_layers = [pag_applied_layers]
if pag_attn_processors is not None:
if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
raise ValueError("Expected a tuple of two attention processors")
for i in range(len(pag_applied_layers)):
if not isinstance(pag_applied_layers[i], str):
raise ValueError(
f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}"
)
self.pag_applied_layers = pag_applied_layers
self._pag_attn_processors = pag_attn_processors
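Because the layer identifiers are stored as strings and only resolved against module names later, they can also be re-targeted on an existing pipeline; a short usage sketch (`pipe` is assumed to be any PAG-enabled pipeline):

```python
# Match self-attention in block 14 plus blocks 20 and 21 via a regex identifier.
pipe.set_pag_applied_layers(["blocks.14", "blocks.(20|21)"])
```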
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale
def pag_scale(self):
"""
Get the scale factor for the perturbed attention guidance.
"""
def pag_scale(self) -> float:
r"""Get the scale factor for the perturbed attention guidance."""
return self._pag_scale
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale
def pag_adaptive_scale(self):
"""
Get the adaptive scale factor for the perturbed attention guidance.
"""
def pag_adaptive_scale(self) -> float:
r"""Get the adaptive scale factor for the perturbed attention guidance."""
return self._pag_adaptive_scale
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling
def do_pag_adaptive_scaling(self):
"""
Check if the adaptive scaling is enabled for the perturbed attention guidance.
"""
def do_pag_adaptive_scaling(self) -> bool:
r"""Check if the adaptive scaling is enabled for the perturbed attention guidance."""
return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance
def do_perturbed_attention_guidance(self):
"""
Check if the perturbed attention guidance is enabled.
"""
def do_perturbed_attention_guidance(self) -> bool:
r"""Check if the perturbed attention guidance is enabled."""
return self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer
def pag_attn_processors(self):
def pag_attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the model,
with the layer names as keys.
"""
if self._pag_attn_processors is None:
return {}
valid_attn_processors = {x.__class__ for x in self._pag_attn_processors}
processors = {}
for name, proc in self.transformer.attn_processors.items():
if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
# We could have iterated through self.components.items() and checked if a component is
# subclassed from `ModelMixin`, but that can include a VAE too.
if hasattr(self, "unet"):
denoiser_module = self.unet
elif hasattr(self, "transformer"):
denoiser_module = self.transformer
else:
raise ValueError("No denoiser module found.")
for name, proc in denoiser_module.attn_processors.items():
if proc.__class__ in valid_attn_processors:
processors[name] = proc
return processors


@@ -0,0 +1,953 @@
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, HunyuanDiT2DModel
from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pag_utils import PAGMixin
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import AutoPipelineForText2Image
>>> pipe = AutoPipelineForText2Image.from_pretrained(
... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
... torch_dtype=torch.float16,
... enable_pag=True,
... pag_applied_layers=["blocks.14"],
... ).to("cuda")
>>> # prompt = "an astronaut riding a horse"
>>> prompt = "一个宇航员在骑马"
>>> image = pipe(prompt, guidance_scale=4, pag_scale=3).images[0]
```
"""
STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)
STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
SUPPORTED_SHAPE = [
(1024, 1024),
(1280, 1280), # 1:1
(1024, 768),
(1152, 864),
(1280, 960), # 4:3
(768, 1024),
(864, 1152),
(960, 1280), # 3:4
(1280, 768), # 16:9
(768, 1280), # 9:16
]
def map_to_standard_shapes(target_width, target_height):
target_ratio = target_width / target_height
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
return width, height
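A worked example of the snapping logic above, assuming the module-level tables are in scope: a 1100x800 request has aspect ratio 1.375, closest to 4:3, and its area (880000) is closest to 1024*768 among the 4:3 shapes.

```python
# Snaps to the nearest standard aspect ratio first, then the nearest area.
assert map_to_standard_shapes(1100, 800) == (1024, 768)
```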
def get_resize_crop_region_for_grid(src, tgt_size):
th = tw = tgt_size
h, w = src
r = h / w
# resize
if r > 1:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
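A worked example of the centered-crop computation above: a 48x80 (height, width) grid fit into a 64x64 target is width-limited, so it resizes to 38x64 and is centered vertically with a 13-row offset.

```python
# Returns ((crop_top, crop_left), (crop_bottom, crop_right)).
assert get_resize_crop_region_for_grid((48, 80), 64) == ((13, 0), (51, 64))
```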
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
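A standalone sanity check (hypothetical tensors, assuming `rescale_noise_cfg` from above is in scope): with `guidance_rescale=1.0` the output's per-sample standard deviation is matched exactly to the text-conditional prediction, so a constant over-amplification of the CFG output is fully undone.

```python
import torch

noise_pred_text = torch.randn(2, 4, 8, 8)
noise_cfg = noise_pred_text * 3.0  # over-amplified CFG output
out = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
assert torch.allclose(out, noise_pred_text)
```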
class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
r"""
Pipeline for English/Chinese-to-image generation using HunyuanDiT and [Perturbed Attention
Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP (fine-tuned by
the HunyuanDiT team).
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use
`sdxl-vae-fp16-fix`.
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
HunyuanDiT uses a fine-tuned bilingual CLIP model.
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
transformer ([`HunyuanDiT2DModel`]):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = [
"safety_checker",
"feature_extractor",
"text_encoder_2",
"tokenizer_2",
"text_encoder",
"tokenizer",
]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"prompt_embeds_2",
"negative_prompt_embeds_2",
]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: BertModel,
tokenizer: BertTokenizer,
transformer: HunyuanDiT2DModel,
scheduler: DDPMScheduler,
safety_checker: Optional[StableDiffusionSafetyChecker] = None,
feature_extractor: Optional[CLIPImageProcessor] = None,
requires_safety_checker: bool = True,
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
pag_applied_layers: Union[str, List[str]] = "blocks.1", # e.g. "blocks.16.attn1", "blocks.16", "16"
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
text_encoder_2=text_encoder_2,
)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = (
self.transformer.config.sample_size
if hasattr(self, "transformer") and self.transformer is not None
else 128
)
self.set_pag_applied_layers(
pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0())
)
# Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt
def encode_prompt(
self,
prompt: str,
device: torch.device = None,
dtype: torch.dtype = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: Optional[int] = None,
text_encoder_index: int = 0,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
dtype (`torch.dtype`):
torch dtype
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
text_encoder_index (`int`, *optional*):
Index of the text encoder to use. `0` for clip and `1` for T5.
"""
if dtype is None:
if self.text_encoder_2 is not None:
dtype = self.text_encoder_2.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
if device is None:
device = self._execution_device
tokenizers = [self.tokenizer, self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2]
tokenizer = tokenizers[text_encoder_index]
text_encoder = text_encoders[text_encoder_index]
if max_sequence_length is None:
if text_encoder_index == 0:
max_length = 77
if text_encoder_index == 1:
max_length = 256
else:
max_length = max_sequence_length
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = text_encoder(
text_input_ids.to(device),
attention_mask=prompt_attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
attention_mask=negative_prompt_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
prompt_embeds_2=None,
negative_prompt_embeds_2=None,
prompt_attention_mask_2=None,
negative_prompt_attention_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
raise ValueError(
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
raise ValueError(
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
f" {negative_prompt_embeds_2.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
prompt_attention_mask_2: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = (1024, 1024),
target_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
use_resolution_binning: bool = True,
pag_scale: float = 3.0,
pag_adaptive_scale: float = 0.0,
):
r"""
The call function to the pipeline for generation with HunyuanDiT.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` and `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A callback function or a list of callback functions to be called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
A list of tensor inputs that should be passed to the callback function. If not defined, all tensor
inputs will be passed.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Rescale factor applied to the noise prediction, based on the findings of [Common Diffusion Noise
                Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
The original size of the image. Used to calculate the time ids.
target_size (`Tuple[int, int]`, *optional*):
The target size of the image. Used to calculate the time ids.
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
The top left coordinates of the crop. Used to calculate the time ids.
use_resolution_binning (`bool`, *optional*, defaults to `True`):
Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
pag_scale (`float`, *optional*, defaults to 3.0):
The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
guidance will not be used.
pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
used.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
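        # round height/width down to the nearest multiple of 16 (the VAE downsamples by a factor
        # of 8 and the transformer patchifies by 2, so the latent grid must stay integral)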
height = int((height // 16) * 16)
width = int((width // 16) * 16)
if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
width, height = map_to_standard_shapes(width, height)
height = int(height)
width = int(width)
logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
self._pag_scale = pag_scale
self._pag_adaptive_scale = pag_adaptive_scale
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=77,
text_encoder_index=0,
)
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds_2,
negative_prompt_embeds=negative_prompt_embeds_2,
prompt_attention_mask=prompt_attention_mask_2,
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
max_sequence_length=256,
text_encoder_index=1,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create image_rotary_emb, style embedding & time ids
grid_height = height // 8 // self.transformer.config.patch_size
grid_width = width // 8 // self.transformer.config.patch_size
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
)
style = torch.tensor([0], device=device)
target_size = target_size or (height, width)
add_time_ids = list(original_size + target_size + crops_coords_top_left)
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
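        # When perturbed-attention guidance is enabled as well, the batch is tripled instead:
        # the [negative (uncond), text (cond), perturbed-attention] branches run in one forward
        # pass and are recombined by `_apply_perturbed_attention_guidance` in the denoising loop.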
if self.do_perturbed_attention_guidance:
prompt_embeds = self._prepare_perturbed_attention_guidance(
prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
)
prompt_attention_mask = self._prepare_perturbed_attention_guidance(
prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance
)
prompt_embeds_2 = self._prepare_perturbed_attention_guidance(
prompt_embeds_2, negative_prompt_embeds_2, self.do_classifier_free_guidance
)
prompt_attention_mask_2 = self._prepare_perturbed_attention_guidance(
prompt_attention_mask_2, negative_prompt_attention_mask_2, self.do_classifier_free_guidance
)
add_time_ids = torch.cat([add_time_ids] * 3, dim=0)
style = torch.cat([style] * 3, dim=0)
elif self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
style = torch.cat([style] * 2, dim=0)
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
prompt_embeds_2 = prompt_embeds_2.to(device=device)
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
batch_size * num_images_per_prompt, 1
)
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
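        # warmup steps matter only for higher-order schedulers; they are used below to decide
        # when the progress bar should advance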
self._num_timesteps = len(timesteps)
if self.do_perturbed_attention_guidance:
original_attn_proc = self.transformer.attn_processors
self._set_pag_attn_processor(
pag_applied_layers=self.pag_applied_layers,
do_classifier_free_guidance=self.do_classifier_free_guidance,
)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
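                # the multiplier is 1, 2 or 3 depending on which of CFG and PAG are active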
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
text_embedding_mask=prompt_attention_mask,
encoder_hidden_states_t5=prompt_embeds_2,
text_embedding_mask_t5=prompt_attention_mask_2,
image_meta_size=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
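                # the transformer predicts both the noise and a learned variance, concatenated
                # along the channel dimension; only the noise half is passed to the scheduler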
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
                elif self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                    if guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf; `noise_pred_text` is only
                        # defined on this plain-CFG path, so rescaling is applied here rather than after
                        # the PAG branch (where it would raise a NameError)
                        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
negative_prompt_embeds_2 = callback_outputs.pop(
"negative_prompt_embeds_2", negative_prompt_embeds_2
)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# 9. Offload all models
self.maybe_free_model_hooks()
if self.do_perturbed_attention_guidance:
self.transformer.set_attn_processor(original_attn_proc)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
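
A minimal usage sketch of the new pipeline; the checkpoint name and the layer selection below
are illustrative assumptions, not part of this diff:

>>> import torch
>>> from diffusers import AutoPipelineForText2Image
>>> pipe = AutoPipelineForText2Image.from_pretrained(
...     "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
...     torch_dtype=torch.float16,
...     enable_pag=True,
...     pag_applied_layers=["blocks.(16|17|18|19)"],
... )
>>> pipe = pipe.to("cuda")
>>> image = pipe("an astronaut riding a horse", pag_scale=3.0).images[0]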


@@ -40,7 +40,7 @@ from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_1024_BIN,
)
from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
-from .pag_utils import PixArtPAGMixin
+from .pag_utils import PAGMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -61,7 +61,7 @@ EXAMPLE_DOC_STRING = """
>>> pipe = AutoPipelineForText2Image.from_pretrained(
... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
... torch_dtype=torch.float16,
-        ...     pag_applied_layers=[14],
+        ...     pag_applied_layers=["blocks.14"],
... enable_pag=True,
... )
>>> pipe = pipe.to("cuda")
@@ -132,7 +132,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class PixArtSigmaPAGPipeline(DiffusionPipeline, PixArtPAGMixin):
+class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
r"""
[PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation
using PixArt-Sigma.
@@ -164,7 +164,7 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PixArtPAGMixin):
vae: AutoencoderKL,
transformer: PixArtTransformer2DModel,
scheduler: KarrasDiffusionSchedulers,
-        pag_applied_layers: Union[str, List[str]] = "1",  # 1st transformer block
+        pag_applied_layers: Union[str, List[str]] = "blocks.1",  # 1st transformer block
):
super().__init__()


@@ -129,7 +129,7 @@ class AnimateDiffPAGPipeline(
scheduler: KarrasDiffusionSchedulers,
feature_extractor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModelWithProjection = None,
-        pag_applied_layers: Union[str, List[str]] = "mid",  # ["mid"], ["down.block_1"], ["up.block_0.attentions_0"]
+        pag_applied_layers: Union[str, List[str]] = "mid_block.*attn1",  # ["mid"], ["down_blocks.1"]
):
super().__init__()
if isinstance(unet, UNet2DConditionModel):


@@ -332,6 +332,21 @@ class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class HunyuanDiTPAGPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
class HunyuanDiTPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]


@@ -429,7 +429,10 @@ class AnimateDiffPAGPipelineFastTests(
pipe.set_progress_bar_config(disable=None)
# pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers
-        all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k]
+        # Note that for motion modules in AnimateDiff, both attn1 and attn2 are self-attention
+        all_self_attn_layers = [
+            k for k in pipe.unet.attn_processors.keys() if "attn1" in k or ("motion_modules" in k and "attn2" in k)
+        ]
original_attn_procs = pipe.unet.attn_processors
pag_layers = [
"down",
@@ -439,12 +442,13 @@ class AnimateDiffPAGPipelineFastTests(
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
-        # pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.motion_modules_0"] should apply to all self-attention layers in mid_block, i.e.
+        # pag_applied_layers = ["mid"], or ["mid_block.0"] should apply to all self-attention layers in mid_block, i.e.
# mid_block.motion_modules.0.transformer_blocks.0.attn1.processor
# mid_block.attentions.0.transformer_blocks.0.attn1.processor
         all_self_attn_mid_layers = [
-            "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor",
             "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
+            "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor",
+            "mid_block.motion_modules.0.transformer_blocks.0.attn2.processor",
         ]
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["mid"]
@@ -452,17 +456,17 @@ class AnimateDiffPAGPipelineFastTests(
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0"]
+        pag_layers = ["mid_block"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_0", "mid.block_0.motion_modules_0"]
+        pag_layers = ["mid_block.(attentions|motion_modules)"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_1"]
+        pag_layers = ["mid_block.attentions.1"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
@@ -474,19 +478,19 @@ class AnimateDiffPAGPipelineFastTests(
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["down"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
-        assert len(pipe.pag_attn_processors) == 6
+        assert len(pipe.pag_attn_processors) == 10
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_0"]
+        pag_layers = ["down_blocks.0"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
-        assert (len(pipe.pag_attn_processors)) == 4
+        assert (len(pipe.pag_attn_processors)) == 6
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1"]
+        pag_layers = ["blocks.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
-        assert len(pipe.pag_attn_processors) == 2
+        assert len(pipe.pag_attn_processors) == 10
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1.motion_modules_2"]
+        pag_layers = ["motion_modules.42"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)


@@ -0,0 +1,358 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, BertModel, T5EncoderModel
from diffusers import (
AutoencoderKL,
DDPMScheduler,
HunyuanDiT2DModel,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = HunyuanDiTPAGPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
def get_dummy_components(self):
torch.manual_seed(0)
transformer = HunyuanDiT2DModel(
sample_size=16,
num_layers=2,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,
in_channels=4,
cross_attention_dim=32,
cross_attention_dim_t5=32,
pooled_projection_dim=16,
hidden_size=24,
activation_fn="gelu-approximate",
)
torch.manual_seed(0)
vae = AutoencoderKL()
scheduler = DDPMScheduler()
text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
"use_resolution_binning": False,
"pag_scale": 0.0,
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
self.assertEqual(image.shape, (1, 16, 16, 3))
expected_slice = np.array(
[0.56939435, 0.34541583, 0.35915792, 0.46489206, 0.38775963, 0.45004836, 0.5957267, 0.59481275, 0.33287364]
)
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_sequential_cpu_offload_forward_pass(self):
# TODO(YiYi) need to fix later
pass
def test_sequential_offload_forward_pass_twice(self):
# TODO(YiYi) need to fix later
pass
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(
expected_max_diff=1e-3,
)
def test_save_load_optional_components(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs["prompt"]
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = pipe.encode_prompt(
prompt,
device=torch_device,
dtype=torch.float32,
text_encoder_index=1,
)
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"prompt_embeds_2": prompt_embeds_2,
"prompt_attention_mask_2": prompt_attention_mask_2,
"negative_prompt_embeds_2": negative_prompt_embeds_2,
"negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
# set all optional components to None
for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for optional_component in pipe._optional_components:
self.assertTrue(
getattr(pipe_loaded, optional_component) is None,
f"`{optional_component}` did not stay set to None after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"prompt_embeds_2": prompt_embeds_2,
"prompt_attention_mask_2": prompt_attention_mask_2,
"negative_prompt_embeds_2": negative_prompt_embeds_2,
"negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1e-4)
def test_feed_forward_chunking(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_no_chunking = image[0, -3:, -3:, -1]
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_chunking = image[0, -3:, -3:, -1]
max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
self.assertLess(max_diff, 1e-4)
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image = pipe(**inputs)[0]
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image_fused = pipe(**inputs)[0]
image_slice_fused = image_fused[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline (expect same output when pag is disabled)
pipe_sd = HunyuanDiTPipeline(**components)
pipe_sd = pipe_sd.to(device)
pipe_sd.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
assert (
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
# pag disabled with pag_scale=0.0
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 0.0
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
# pag enabled
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 3.0
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
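        # pag_scale=0.0 must reproduce the base pipeline output exactly, while a
        # non-zero pag_scale must change the output measurably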
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
def test_pag_applied_layers(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k]
original_attn_procs = pipe.transformer.attn_processors
pag_layers = ["blocks.0", "blocks.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
# blocks.0
block_0_self_attn = ["blocks.0.attn1.processor"]
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.0"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.0.attn1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.(0|1)"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert (len(pipe.pag_attn_processors)) == 2
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.0", r"blocks\.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 2
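
These fast tests can be run locally with, e.g., `pytest -k HunyuanDiTPAGPipeline` (assuming the
repository's standard pytest setup).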


@@ -127,7 +127,7 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
out = pipe(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
-        components["pag_applied_layers"] = [1]
+        components["pag_applied_layers"] = ["blocks.1"]
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
@@ -158,7 +158,7 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# "attn1" should apply to all self-attention layers.
all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k]
-        pag_layers = [0, 1]
+        pag_layers = ["blocks.0", "blocks.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
@@ -228,7 +228,7 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
-            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1])
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"])
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
@@ -282,7 +282,7 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipe.save_pretrained(tmpdir, safe_serialization=False)
with CaptureLogger(logger) as cap_logger:
-                pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1])
+                pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"])
for name in pipe_loaded.components.keys():
if name not in pipe_loaded._optional_components:


@@ -213,18 +213,18 @@ class StableDiffusionPAGPipelineFastTests(
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0"]
+        pag_layers = ["mid_block"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_0"]
+        pag_layers = ["mid_block.attentions.0"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        # pag_applied_layers = ["mid_block.attentions.1"] does not exist in the model
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_1"]
+        pag_layers = ["mid_block.attentions.1"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
@@ -239,17 +239,17 @@ class StableDiffusionPAGPipelineFastTests(
assert len(pipe.pag_attn_processors) == 2
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_0"]
+        pag_layers = ["down_blocks.0"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1"]
+        pag_layers = ["down_blocks.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 2
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1.attentions_1"]
+        pag_layers = ["down_blocks.1.attentions.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 1


@@ -225,18 +225,18 @@ class StableDiffusionXLPAGPipelineFastTests(
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0"]
+        pag_layers = ["mid_block"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_0"]
+        pag_layers = ["mid_block.attentions.0"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        # pag_applied_layers = ["mid_block.attentions.1"] does not exist in the model
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_1"]
+        pag_layers = ["mid_block.attentions.1"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
@@ -251,17 +251,17 @@ class StableDiffusionXLPAGPipelineFastTests(
assert len(pipe.pag_attn_processors) == 4
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_0"]
+        pag_layers = ["down_blocks.0"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1"]
+        pag_layers = ["down_blocks.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 4
pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1.attentions_1"]
+        pag_layers = ["down_blocks.1.attentions.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 2