Mirror of https://github.com/huggingface/diffusers.git
attention refactor: the trilogy (#3387)
* Replace `AttentionBlock` with `Attention`
* use `_from_deprecated_attn_block` check re: @patrickvonplaten
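For downstream code that still builds the deprecated block, the replacements in this commit map the old `AttentionBlock(channels, num_head_channels, ...)` arguments onto `Attention` roughly as sketched below. This is a minimal illustration distilled from the diff, not a helper shipped by the library; the function name is made up for the example.

from diffusers.models.attention_processor import Attention


def attention_from_deprecated_block(
    channels, num_head_channels=None, norm_num_groups=32, rescale_output_factor=1.0, eps=1e-5
):
    # Mirrors how the UNet/VAE blocks below construct Attention in place of AttentionBlock.
    return Attention(
        channels,
        heads=channels // num_head_channels if num_head_channels is not None else 1,
        dim_head=num_head_channels if num_head_channels is not None else channels,
        rescale_output_factor=rescale_output_factor,
        eps=eps,
        norm_num_groups=norm_num_groups,
        residual_connection=True,
        bias=True,
        upcast_softmax=True,
        _from_deprecated_attn_block=True,
    )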
@@ -11,189 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    Uses three q, k, v linear layers to compute attention.

    Parameters:
        channels (`int`): The number of channels in the input and output.
        num_head_channels (`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, bias=True)

        self._use_memory_efficient_attention_xformers = False
        self._use_2_0_attn = True
        self._attention_op = None

    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)
        if merge_head_and_batch:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
        head_size = self.num_heads

        if unmerge_head_and_batch:
            batch_head_size, seq_len, dim = tensor.shape
            batch_size = batch_head_size // head_size

            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
        else:
            batch_size, _, seq_len, dim = tensor.shape

        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
        return tensor

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        if use_memory_efficient_attention_xformers:
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self._attention_op = attention_op

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        scale = 1 / math.sqrt(self.channels / self.num_heads)

        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn

        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)

        if self._use_memory_efficient_attention_xformers:
            # Memory efficient attention
            hidden_states = xformers.ops.memory_efficient_attention(
                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
            )
            hidden_states = hidden_states.to(query_proj.dtype)
        elif use_torch_2_0_attn:
            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(
                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
            )
            hidden_states = hidden_states.to(query_proj.dtype)
        else:
            attention_scores = torch.baddbmm(
                torch.empty(
                    query_proj.shape[0],
                    query_proj.shape[1],
                    key_proj.shape[1],
                    dtype=query_proj.dtype,
                    device=query_proj.device,
                ),
                query_proj,
                key_proj.transpose(-1, -2),
                beta=0,
                alpha=scale,
            )
            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
            hidden_states = torch.bmm(attention_probs, value_proj)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
@@ -65,6 +65,10 @@ class Attention(nn.Module):
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block=False,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()

@@ -72,6 +76,12 @@ class Attention(nn.Module):
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection

        # we make use of this private variable to know whether this class is loaded
        # with an deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

@@ -91,7 +101,7 @@ class Attention(nn.Module):
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

@@ -407,10 +417,22 @@ class AttnProcessor:
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:

@@ -434,6 +456,14 @@ class AttnProcessor:
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

@@ -474,11 +504,22 @@ class LoRAAttnProcessor(nn.Module):
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

@@ -502,6 +543,14 @@ class LoRAAttnProcessor(nn.Module):
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

@@ -762,12 +811,23 @@ class XFormersAttnProcessor:
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:

@@ -792,6 +852,15 @@ class XFormersAttnProcessor:
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

@@ -801,6 +870,14 @@ class AttnProcessor2_0:
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

@@ -812,6 +889,9 @@ class AttnProcessor2_0:
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:

@@ -840,6 +920,15 @@ class AttnProcessor2_0:
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

@@ -858,11 +947,22 @@ class LoRAXFormersAttnProcessor(nn.Module):
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

@@ -887,6 +987,14 @@ class LoRAXFormersAttnProcessor(nn.Module):
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

@@ -980,11 +1088,22 @@ class SlicedAttnProcessor:
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

@@ -1025,6 +1144,14 @@ class SlicedAttnProcessor:
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
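With the 4D handling added to every processor above, the refactored `Attention` can be applied directly to an image-shaped (batch, channels, height, width) tensor, the way the deprecated `AttentionBlock` was used. A hedged usage sketch, with shapes and hyperparameters borrowed from the `AttentionBlock` test removed later in this commit:

import torch

from diffusers.models.attention_processor import Attention

# Roughly equivalent to AttentionBlock(channels=32, num_head_channels=1, norm_num_groups=32, eps=1e-6)
attn = Attention(
    32,
    heads=32,    # channels // num_head_channels
    dim_head=1,  # num_head_channels
    norm_num_groups=32,
    eps=1e-6,
    rescale_output_factor=1.0,
    residual_connection=True,
    bias=True,
    upcast_softmax=True,
    _from_deprecated_attn_block=True,
)

sample = torch.randn(1, 32, 64, 64)
with torch.no_grad():
    out = attn(sample)  # the processor flattens to (batch, h*w, channels) and restores the 4D shape
assert out.shape == sample.shape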
@@ -583,6 +583,7 @@ class ModelMixin(torch.nn.Module):
                if device_map is None:
                    param_device = "cpu"
                    state_dict = load_state_dict(model_file, variant=variant)
                    model._convert_deprecated_attention_blocks(state_dict)
                    # move the params from meta device to cpu
                    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                    if len(missing_keys) > 0:

@@ -625,6 +626,7 @@ class ModelMixin(torch.nn.Module):
            model = cls.from_config(config, **unused_kwargs)

            state_dict = load_state_dict(model_file, variant=variant)
            model._convert_deprecated_attention_blocks(state_dict)

            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                model,

@@ -803,3 +805,47 @@ class ModelMixin(torch.nn.Module):
            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
        else:
            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

    def _convert_deprecated_attention_blocks(self, state_dict):
        deprecated_attention_block_paths = []

        def recursive_find_attn_block(name, module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_paths.append(name)

            for sub_name, sub_module in module.named_children():
                sub_name = sub_name if name == "" else f"{name}.{sub_name}"
                recursive_find_attn_block(sub_name, sub_module)

        recursive_find_attn_block("", self)

        # NOTE: we have to check if the deprecated parameters are in the state dict
        # because it is possible we are loading from a state dict that was already
        # converted

        for path in deprecated_attention_block_paths:
            # group_norm path stays the same

            # query -> to_q
            if f"{path}.query.weight" in state_dict:
                state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
            if f"{path}.query.bias" in state_dict:
                state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")

            # key -> to_k
            if f"{path}.key.weight" in state_dict:
                state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
            if f"{path}.key.bias" in state_dict:
                state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")

            # value -> to_v
            if f"{path}.value.weight" in state_dict:
                state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
            if f"{path}.value.bias" in state_dict:
                state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")

            # proj_attn -> to_out.0
            if f"{path}.proj_attn.weight" in state_dict:
                state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
            if f"{path}.proj_attn.bias" in state_dict:
                state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
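The conversion above renames the deprecated per-block parameters in place: `query` -> `to_q`, `key` -> `to_k`, `value` -> `to_v`, and `proj_attn` -> `to_out.0`, while the `group_norm` path is left untouched. A compact sketch of the same remapping for a single block path (the helper name and mapping table are illustrative, not library API):

RENAMED_SUBMODULES = {"query": "to_q", "key": "to_k", "value": "to_v", "proj_attn": "to_out.0"}


def remap_deprecated_attention_keys(state_dict, path):
    # Move the weight/bias tensors of one deprecated AttentionBlock at `path`
    # onto the parameter names the refactored Attention module expects.
    for old, new in RENAMED_SUBMODULES.items():
        for suffix in ("weight", "bias"):
            old_key = f"{path}.{old}.{suffix}"
            if old_key in state_dict:
                state_dict[f"{path}.{new}.{suffix}"] = state_dict.pop(old_key)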
@@ -18,7 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn

from .attention import AdaGroupNorm, AttentionBlock
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D

@@ -427,12 +427,17 @@ class UNetMidBlock2D(nn.Module):
        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    AttentionBlock(
                    Attention(
                        in_channels,
                        num_head_channels=attn_num_head_channels,
                        heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
                        dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=resnet_groups,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:

@@ -711,12 +716,17 @@ class AttnDownBlock2D(nn.Module):
                )
            )
            attentions.append(
                AttentionBlock(
                Attention(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
                    dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

@@ -1060,12 +1070,17 @@ class AttnDownEncoderBlock2D(nn.Module):
                )
            )
            attentions.append(
                AttentionBlock(
                Attention(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
                    dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

@@ -1134,11 +1149,17 @@ class AttnSkipDownBlock2D(nn.Module):
                )
            )
        self.attentions.append(
            AttentionBlock(
            Attention(
                out_channels,
                num_head_channels=attn_num_head_channels,
                heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
                dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
                rescale_output_factor=output_scale_factor,
                eps=resnet_eps,
                norm_num_groups=32,
                residual_connection=True,
                bias=True,
                upcast_softmax=True,
                _from_deprecated_attn_block=True,
            )
        )

@@ -1703,12 +1724,17 @@ class AttnUpBlock2D(nn.Module):
                )
            )
            attentions.append(
                AttentionBlock(
                Attention(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
                    dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

@@ -2037,12 +2063,17 @@ class AttnUpDecoderBlock2D(nn.Module):
                )
            )
            attentions.append(
                AttentionBlock(
                Attention(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
                    dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

@@ -2109,11 +2140,17 @@ class AttnSkipUpBlock2D(nn.Module):
            )

        self.attentions.append(
            AttentionBlock(
            Attention(
                out_channels,
                num_head_channels=attn_num_head_channels,
                heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
                dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
                rescale_output_factor=output_scale_factor,
                eps=resnet_eps,
                norm_num_groups=32,
                residual_connection=True,
                bias=True,
                upcast_softmax=True,
                _from_deprecated_attn_block=True,
            )
        )
@@ -19,11 +19,11 @@ from typing import Any, Callable, List, Optional, Union
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline

@@ -709,12 +709,14 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
        # make sure the VAE is in float32 mode, as it overflows in float16
        self.vae.to(dtype=torch.float32)

        # TODO(Patrick, William) - clean up when attention is refactored
        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
        use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [
            AttnProcessor2_0,
            XFormersAttnProcessor,
            LoRAXFormersAttnProcessor,
        ]
        # if xformers or torch_2_0 is used attention block does not need
        # to be in float32 which can save lots of memory
        if not use_torch_2_0_attn and not use_xformers:
        if not use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(latents.dtype)
            self.vae.decoder.conv_in.to(latents.dtype)
            self.vae.decoder.mid_block.to(latents.dtype)
@@ -20,7 +20,7 @@ import numpy as np
import torch
from torch import nn

from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel

@@ -314,59 +314,6 @@ class ResnetBlock2DTests(unittest.TestCase):
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


class AttentionBlockTests(unittest.TestCase):
    @unittest.skipIf(
        torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039"
    )
    def test_attention_block_default(self):
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        attentionBlock = AttentionBlock(
            channels=32,
            num_head_channels=1,
            rescale_output_factor=1.0,
            eps=1e-6,
            norm_num_groups=32,
        ).to(torch_device)
        with torch.no_grad():
            attention_scores = attentionBlock(sample)

        assert attention_scores.shape == (1, 32, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_attention_block_sd(self):
        # This version uses SD params and is compatible with mps
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        sample = torch.randn(1, 512, 64, 64).to(torch_device)
        attentionBlock = AttentionBlock(
            channels=512,
            rescale_output_factor=1.0,
            eps=1e-6,
            norm_num_groups=32,
        ).to(torch_device)
        with torch.no_grad():
            attention_scores = attentionBlock(sample)

        assert attention_scores.shape == (1, 512, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-0.6621, -0.0156, -3.2766, 0.8025, -0.8609, 0.2820, 0.0905, -1.1179, -3.2126], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


class Transformer2DModelTests(unittest.TestCase):
    def test_spatial_transformer_default(self):
        torch.manual_seed(0)