from typing import Any, Dict, Optional

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import PixArtTransformer2DModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
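
# ControlNet-Transformer adapter for PixArt-alpha, following the approach in
# arXiv:2401.05252 (PixArt-δ), section 4: trainable copies of the first N base
# transformer blocks, wrapped in zero-initialized projections, inject a control
# signal into a PixArtTransformer2DModel.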


class PixArtControlNetAdapterBlock(nn.Module):
    def __init__(
        self,
        block_index,
        # taken from PixArtTransformer2DModel
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = 1152,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_single",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        attention_type: Optional[str] = "default",
    ):
        super().__init__()

        self.block_index = block_index
        self.inner_dim = num_attention_heads * attention_head_dim

        # the first block applies a zero-initialized "before" projection to the
        # incoming conditioning
        if self.block_index == 0:
            self.before_proj = nn.Linear(self.inner_dim, self.inner_dim)
            nn.init.zeros_(self.before_proj.weight)
            nn.init.zeros_(self.before_proj.bias)

        self.transformer_block = BasicTransformerBlock(
            self.inner_dim,
            num_attention_heads,
            attention_head_dim,
            dropout=dropout,
            cross_attention_dim=cross_attention_dim,
            activation_fn=activation_fn,
            num_embeds_ada_norm=num_embeds_ada_norm,
            attention_bias=attention_bias,
            upcast_attention=upcast_attention,
            norm_type=norm_type,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            attention_type=attention_type,
        )
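
        # the "after" projection is zero-initialized as well, so a freshly
        # created adapter block contributes nothing until it is trained
        # (the usual ControlNet initialization recipe)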
        self.after_proj = nn.Linear(self.inner_dim, self.inner_dim)
        nn.init.zeros_(self.after_proj.weight)
        nn.init.zeros_(self.after_proj.bias)

    def train(self, mode: bool = True):
        self.transformer_block.train(mode)

        if self.block_index == 0:
            self.before_proj.train(mode)

        self.after_proj.train(mode)

        return self

    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ):
        if self.block_index == 0:
            controlnet_states = self.before_proj(controlnet_states)
            controlnet_states = hidden_states + controlnet_states

        controlnet_states_down = self.transformer_block(
            hidden_states=controlnet_states,
            encoder_hidden_states=encoder_hidden_states,
            timestep=timestep,
            added_cond_kwargs=added_cond_kwargs,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            class_labels=None,
        )
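
        # `controlnet_states_left` is the (initially zero) residual added to the
        # base transformer's hidden states; `controlnet_states_down` feeds the
        # next adapter block in the chain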
        controlnet_states_left = self.after_proj(controlnet_states_down)

        return controlnet_states_left, controlnet_states_down


class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin):
    # N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 (ControlNet-Transformer);
    # i.e. the first 13 of PixArt-XL's 28 base blocks get a trainable copy
    @register_to_config
    def __init__(self, num_layers=13) -> None:
        super().__init__()

        self.num_layers = num_layers

        self.controlnet_blocks = nn.ModuleList(
            [PixArtControlNetAdapterBlock(block_index=i) for i in range(num_layers)]
        )

    @classmethod
    def from_transformer(cls, transformer: PixArtTransformer2DModel):
        control_net = cls()

        # copy the weights of the first `num_layers` blocks from the transformer
        for depth in range(control_net.num_layers):
            control_net.controlnet_blocks[depth].transformer_block.load_state_dict(
                transformer.transformer_blocks[depth].state_dict()
            )

        return control_net
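
    # note: after `from_transformer`, each adapter's transformer_block is an
    # exact copy of the corresponding base block; the zero-initialized
    # before/after projections are the only new weights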

    def train(self, mode: bool = True):
        for block in self.controlnet_blocks:
            block.train(mode)

        return self


class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
        transformer: PixArtTransformer2DModel,
        controlnet: PixArtControlNetAdapterModel,
        blocks_num=13,
        init_from_transformer=False,
        training=False,
    ):
        super().__init__()

        self.blocks_num = blocks_num
        self.gradient_checkpointing = False
        self.register_to_config(**transformer.config)
        self.training = training

        if init_from_transformer:
            # initialize the adapter blocks from the transformer's weights
            controlnet = PixArtControlNetAdapterModel.from_transformer(transformer)

        self.transformer = transformer
        self.controlnet = controlnet

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        controlnet_cond: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        if self.transformer.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch_size = hidden_states.shape[0]
        height, width = (
            hidden_states.shape[-2] // self.transformer.config.patch_size,
            hidden_states.shape[-1] // self.transformer.config.patch_size,
        )
        hidden_states = self.transformer.pos_embed(hidden_states)

        timestep, embedded_timestep = self.transformer.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        if self.transformer.caption_projection is not None:
            encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        controlnet_states_down = None
        if controlnet_cond is not None:
            # the control latent is patchified and position-embedded exactly like the denoised latent
            controlnet_states_down = self.transformer.pos_embed(controlnet_cond)
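
        # ControlNet-Transformer wiring for the loop below: base block 0 runs
        # unchanged; before each base block 1..blocks_num, adapter block i-1
        # transforms the control stream and its (initially zero) output is
        # added to the hidden states as a residual.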
        # 2. Blocks
        for block_index, block in enumerate(self.transformer.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # rc todo: support gradient checkpointing for training
                raise NotImplementedError(
                    "Gradient checkpointing is not supported for the controlnet transformer model yet."
                )

            # the controlnet blocks are only used for blocks 1 to self.blocks_num
            if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None:
                controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[block_index - 1](
                    hidden_states=hidden_states,  # used only in the first adapter block
                    controlnet_states=controlnet_states_down,
                    encoder_hidden_states=encoder_hidden_states,
                    timestep=timestep,
                    added_cond_kwargs=added_cond_kwargs,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )

                hidden_states = hidden_states + controlnet_states_left

            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=None,
            )

        # 3. Output
        shift, scale = (
            self.transformer.scale_shift_table[None]
            + embedded_timestep[:, None].to(self.transformer.scale_shift_table.device)
        ).chunk(2, dim=1)
        hidden_states = self.transformer.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
        hidden_states = self.transformer.proj_out(hidden_states)
        hidden_states = hidden_states.squeeze(1)

        # unpatchify: [batch, h*w tokens, p*p*C] -> [batch, h, w, p, p, C] -> [batch, C, h*p, w*p]
        hidden_states = hidden_states.reshape(
            shape=(
                -1,
                height,
                width,
                self.transformer.config.patch_size,
                self.transformer.config.patch_size,
                self.transformer.out_channels,
            )
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(
                -1,
                self.transformer.out_channels,
                height * self.transformer.config.patch_size,
                width * self.transformer.config.patch_size,
            )
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
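

# A minimal usage sketch (illustrative; the checkpoint id is an assumption and
# any PixArt-alpha transformer checkpoint should work the same way):
#
#   transformer = PixArtTransformer2DModel.from_pretrained(
#       "PixArt-alpha/PixArt-XL-2-512x512", subfolder="transformer"
#   )
#   controlnet = PixArtControlNetAdapterModel.from_transformer(transformer)
#   model = PixArtControlNetTransformerModel(transformer, controlnet)
#   out = model(
#       latents,                              # [batch, in_channels, height, width]
#       encoder_hidden_states=prompt_embeds,  # T5 text embeddings
#       timestep=t,
#       controlnet_cond=control_latents,      # encoded control image, same shape as latents
#       added_cond_kwargs={"resolution": None, "aspect_ratio": None},
#   ).sample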