Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Implement CustomDiffusionAttnProcessor2_0. (#4604)
* Implement `CustomDiffusionAttnProcessor2_0`
* Doc-strings and type annotations for `CustomDiffusionAttnProcessor2_0`. (#1)
* Update attnprocessor.md
* Update attention_processor.py
* Interops for `CustomDiffusionAttnProcessor2_0`.
* Formatted `attention_processor.py`.
* Formatted doc-string in `attention_processor.py`.
* Conditional CustomDiffusion2_0 for training example.
* Remove unnecessary reference impl in comments.
* Fix `save_attn_procs`.
@@ -17,6 +17,9 @@ An attention processor is a class for applying different types of attention mechanisms.
 ## CustomDiffusionAttnProcessor
 [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
 
+## CustomDiffusionAttnProcessor2_0
+[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
+
 ## AttnAddedKVProcessor
 [[autodoc]] models.attention_processor.AttnAddedKVProcessor
 

@@ -39,4 +42,4 @@ An attention processor is a class for applying different types of attention mechanisms.
 [[autodoc]] models.attention_processor.SlicedAttnProcessor
 
 ## SlicedAttnAddedKVProcessor
-[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
+[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
@@ -51,7 +51,11 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
+from diffusers.models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
+    CustomDiffusionXFormersAttnProcessor,
+)
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -870,7 +874,9 @@ def main(args):
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    attention_class = CustomDiffusionAttnProcessor
+    attention_class = (
+        CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
+    )
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
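The training example now chooses the processor class at import time: `CustomDiffusionAttnProcessor2_0` when `torch.nn.functional` exposes `scaled_dot_product_attention` (PyTorch 2.x), and the original `CustomDiffusionAttnProcessor` otherwise. The sketch below shows, in hedged form, how such an `attention_class` can be wired into a UNet, loosely following the Custom Diffusion training script; the checkpoint id and the `train_kv`/`train_q_out` choices are illustrative assumptions, not values taken from this commit.

```python
import torch.nn.functional as F
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

# Illustrative checkpoint; any UNet2DConditionModel behaves the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Prefer the SDPA-based processor, fall back on PyTorch < 2.0.
attention_class = (
    CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
)

attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers ("attn1") have no text conditioning and stay frozen.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    if cross_attention_dim is not None:
        layer = unet.get_submodule(name.split(".processor")[0])
        proc = attention_class(
            train_kv=True, train_q_out=False, hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )
        # Start the new K/V projections from the pretrained attention weights.
        proc.load_state_dict(
            {"to_k_custom_diffusion.weight": layer.to_k.weight, "to_v_custom_diffusion.weight": layer.to_v.weight}
        )
    else:
        proc = attention_class(
            train_kv=False, train_q_out=False, hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )
    attn_procs[name] = proc

unet.set_attn_processor(attn_procs)
```

Cross-attention layers get trainable K/V projections initialised from the pretrained weights, while self-attention layers keep the stock computation, which matches how the training script assigns processors.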
@@ -559,6 +559,7 @@ class UNet2DConditionLoadersMixin:
         """
         from .models.attention_processor import (
             CustomDiffusionAttnProcessor,
+            CustomDiffusionAttnProcessor2_0,
             CustomDiffusionXFormersAttnProcessor,
         )
 

@@ -578,7 +579,10 @@ class UNet2DConditionLoadersMixin:
         os.makedirs(save_directory, exist_ok=True)
 
         is_custom_diffusion = any(
-            isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+            isinstance(
+                x,
+                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
+            )
             for (_, x) in self.attn_processors.items()
         )
         if is_custom_diffusion:

@@ -586,7 +590,14 @@ class UNet2DConditionLoadersMixin:
                 {
                     y: x
                     for (y, x) in self.attn_processors.items()
-                    if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+                    if isinstance(
+                        x,
+                        (
+                            CustomDiffusionAttnProcessor,
+                            CustomDiffusionAttnProcessor2_0,
+                            CustomDiffusionXFormersAttnProcessor,
+                        ),
+                    )
                 }
             )
             state_dict = model_to_save.state_dict()
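With `CustomDiffusionAttnProcessor2_0` added to both `isinstance` checks, `save_attn_procs` detects the new processor as a Custom Diffusion layer and serialises it through `AttnProcsLayers` exactly like the other two classes. A short, hedged sketch of the round trip, assuming a `unet` set up as in the previous example; the output directory and `weight_name` are illustrative:

```python
# Persist only the custom-diffusion attention parameters, not the full UNet.
unet.save_attn_procs("custom-diffusion-output", weight_name="pytorch_custom_diffusion_weights.bin")

# Later, restore them onto a freshly loaded UNet with the matching loader.
unet.load_attn_procs("custom-diffusion-output", weight_name="pytorch_custom_diffusion_weights.bin")
```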
@@ -173,7 +173,8 @@ class Attention(nn.Module):
             LORA_ATTENTION_PROCESSORS,
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
-            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+            self.processor,
+            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
         )
         is_added_kv_processor = hasattr(self, "processor") and isinstance(
             self.processor,

@@ -261,7 +262,12 @@ class Attention(nn.Module):
             processor.load_state_dict(self.processor.state_dict())
             processor.to(self.processor.to_q_lora.up.weight.device)
         elif is_custom_diffusion:
-            processor = CustomDiffusionAttnProcessor(
+            attn_processor_class = (
+                CustomDiffusionAttnProcessor2_0
+                if hasattr(F, "scaled_dot_product_attention")
+                else CustomDiffusionAttnProcessor
+            )
+            processor = attn_processor_class(
                 train_kv=self.processor.train_kv,
                 train_q_out=self.processor.train_q_out,
                 hidden_size=self.processor.hidden_size,
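The `Attention` changes above also cover switching backends at runtime: when xFormers is turned off on a layer that currently holds a Custom Diffusion xFormers processor, the replacement is now the SDPA variant whenever `F.scaled_dot_product_attention` exists. A hedged sketch, assuming a `unet` that already carries Custom Diffusion processors and a working xformers installation:

```python
# Route the custom-diffusion layers through xFormers memory-efficient attention...
unet.enable_xformers_memory_efficient_attention()

# ...and back. On PyTorch 2.x this now rebuilds CustomDiffusionAttnProcessor2_0
# rather than the slower CustomDiffusionAttnProcessor.
unet.disable_xformers_memory_efficient_attention()
```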
@@ -1156,6 +1162,111 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
         return hidden_states
 
 
+class CustomDiffusionAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled
+    dot-product attention.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        inner_dim = hidden_states.shape[-1]
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     r"""
     Processor for implementing sliced attention.
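The new class mirrors `CustomDiffusionAttnProcessor` but runs the attention itself through `F.scaled_dot_product_attention` instead of an explicit softmax over query-key scores. A self-contained, hedged sketch that attaches it to a single `Attention` block and runs a dummy cross-attention forward pass; the dimensions are illustrative assumptions:

```python
import torch
from diffusers.models.attention_processor import Attention, CustomDiffusionAttnProcessor2_0

# One attention block with 8 heads of 40 dims each, so inner_dim == query_dim == 320.
attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
attn.set_processor(
    CustomDiffusionAttnProcessor2_0(
        train_kv=True,      # use the freshly created to_k/to_v projections
        train_q_out=False,  # keep the block's own to_q and to_out
        hidden_size=320,
        cross_attention_dim=768,
    )
)

hidden_states = torch.randn(2, 64, 320)          # (batch, image tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text tokens, cross_attention_dim)

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])
```

Apart from the SDPA call and the per-head reshaping it requires, the control flow, including the first-token detach trick for cross-attention keys and values, is identical to the existing `CustomDiffusionAttnProcessor`.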
@@ -1639,6 +1750,7 @@ AttentionProcessor = Union[
     XFormersAttnAddedKVProcessor,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
     # depraceted
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,