
attention refactor: the trilogy (#3387)

* Replace `AttentionBlock` with `Attention`

* use _from_deprecated_attn_block check re: @patrickvonplaten
Will Berman
2023-05-12 08:54:09 -06:00
committed by GitHub
parent 28f404349d
commit 909742dbd6
6 changed files with 235 additions and 248 deletions

View File

@@ -11,189 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (`int`): The number of channels in the input and output.
num_head_channels (`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""
# IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
def __init__(
self,
channels: int,
num_head_channels: Optional[int] = None,
norm_num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
self.proj_attn = nn.Linear(channels, channels, bias=True)
self._use_memory_efficient_attention_xformers = False
self._use_2_0_attn = True
self._attention_op = None
def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
batch_size, seq_len, dim = tensor.shape
head_size = self.num_heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if merge_head_and_batch:
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
head_size = self.num_heads
if unmerge_head_and_batch:
batch_head_size, seq_len, dim = tensor.shape
batch_size = batch_head_size // head_size
tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
else:
batch_size, _, seq_len, dim = tensor.shape
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
return tensor
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self._attention_op = attention_op
def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
scale = 1 / math.sqrt(self.channels / self.num_heads)
_use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
if self._use_memory_efficient_attention_xformers:
# Memory efficient attention
hidden_states = xformers.ops.memory_efficient_attention(
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
)
hidden_states = hidden_states.to(query_proj.dtype)
elif use_torch_2_0_attn:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.to(query_proj.dtype)
else:
attention_scores = torch.baddbmm(
torch.empty(
query_proj.shape[0],
query_proj.shape[1],
key_proj.shape[1],
dtype=query_proj.dtype,
device=query_proj.device,
),
query_proj,
key_proj.transpose(-1, -2),
beta=0,
alpha=scale,
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
hidden_states = torch.bmm(attention_probs, value_proj)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
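For readers tracking the replacement: the configuration of the removed block maps onto the new `Attention` class roughly as below. This is a sketch of the mapping this commit applies in `unet_2d_blocks.py`; the concrete `channels`/`num_head_channels` values are illustrative, not taken from the diff.

```python
# Sketch: expressing an old AttentionBlock configuration with Attention.
from diffusers.models.attention_processor import Attention

channels = 512
num_head_channels = None  # AttentionBlock's default: a single head over all channels

attn = Attention(
    channels,
    heads=channels // num_head_channels if num_head_channels is not None else 1,
    dim_head=num_head_channels if num_head_channels is not None else channels,
    rescale_output_factor=1.0,
    eps=1e-5,
    norm_num_groups=32,
    residual_connection=True,   # AttentionBlock always added the residual
    bias=True,                  # its query/key/value/proj_attn linears all carried biases
    upcast_softmax=True,
    _from_deprecated_attn_block=True,  # lets the loading code convert old checkpoints
)
```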
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""

View File

@@ -65,6 +65,10 @@ class Attention(nn.Module):
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block=False,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
@@ -72,6 +76,12 @@ class Attention(nn.Module):
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
# we make use of this private variable to know whether this class is loaded
# with a deprecated state dict so that we can convert it on the fly
self._from_deprecated_attn_block = _from_deprecated_attn_block
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
@@ -91,7 +101,7 @@ class Attention(nn.Module):
)
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
else:
self.group_norm = None
@@ -407,10 +417,22 @@ class AttnProcessor:
encoder_hidden_states=None,
attention_mask=None,
):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
@@ -434,6 +456,14 @@ class AttnProcessor:
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
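The practical effect of the new branches is that a processor-backed `Attention` can consume the 4-D feature maps the deprecated block used to handle. A minimal sketch (not part of the diff; module configuration and shapes are illustrative):

```python
# Sketch: an Attention module configured like the old AttentionBlock now accepts
# a (batch, channels, height, width) tensor directly and returns the same shape.
import torch
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=32,
    heads=1,
    dim_head=32,
    norm_num_groups=32,
    residual_connection=True,
    bias=True,
)

sample = torch.randn(1, 32, 64, 64)  # spatial feature map
with torch.no_grad():
    out = attn(sample)  # the processor flattens to (1, 64*64, 32) internally, then restores 4-D
assert out.shape == sample.shape
```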
@@ -474,11 +504,22 @@ class LoRAAttnProcessor(nn.Module):
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)
@@ -502,6 +543,14 @@ class LoRAAttnProcessor(nn.Module):
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@@ -762,12 +811,23 @@ class XFormersAttnProcessor:
self.attention_op = attention_op
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
@@ -792,6 +852,15 @@ class XFormersAttnProcessor:
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@@ -801,6 +870,14 @@ class AttnProcessor2_0:
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
@@ -812,6 +889,9 @@ class AttnProcessor2_0:
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
@@ -840,6 +920,15 @@ class AttnProcessor2_0:
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@@ -858,11 +947,22 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
@@ -887,6 +987,14 @@ class LoRAXFormersAttnProcessor(nn.Module):
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@@ -980,11 +1088,22 @@ class SlicedAttnProcessor:
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
@@ -1025,6 +1144,14 @@ class SlicedAttnProcessor:
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states

View File

@@ -583,6 +583,7 @@ class ModelMixin(torch.nn.Module):
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
@@ -625,6 +626,7 @@ class ModelMixin(torch.nn.Module):
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
@@ -803,3 +805,47 @@ class ModelMixin(torch.nn.Module):
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
def _convert_deprecated_attention_blocks(self, state_dict):
deprecated_attention_block_paths = []
def recursive_find_attn_block(name, module):
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
deprecated_attention_block_paths.append(name)
for sub_name, sub_module in module.named_children():
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
recursive_find_attn_block(sub_name, sub_module)
recursive_find_attn_block("", self)
# NOTE: we have to check if the deprecated parameters are in the state dict
# because it is possible we are loading from a state dict that was already
# converted
for path in deprecated_attention_block_paths:
# group_norm path stays the same
# query -> to_q
if f"{path}.query.weight" in state_dict:
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
if f"{path}.query.bias" in state_dict:
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
# key -> to_k
if f"{path}.key.weight" in state_dict:
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
if f"{path}.key.bias" in state_dict:
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
# value -> to_v
if f"{path}.value.weight" in state_dict:
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
if f"{path}.value.bias" in state_dict:
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
# proj_attn -> to_out.0
if f"{path}.proj_attn.weight" in state_dict:
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
if f"{path}.proj_attn.bias" in state_dict:
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
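The renaming above reduces to a four-entry mapping applied to `weight`/`bias` pairs. A condensed, standalone sketch for reference (the function and the toy state dict are illustrative, not code added by the commit):

```python
# Sketch of the same on-the-fly key conversion, run over a toy state dict.
import torch

_DEPRECATED_TO_NEW = {
    "query": "to_q",
    "key": "to_k",
    "value": "to_v",
    "proj_attn": "to_out.0",
}

def rename_deprecated_attention_keys(state_dict, path):
    for old, new in _DEPRECATED_TO_NEW.items():
        for suffix in ("weight", "bias"):
            old_key = f"{path}.{old}.{suffix}"
            if old_key in state_dict:  # skip keys that were already converted
                state_dict[f"{path}.{new}.{suffix}"] = state_dict.pop(old_key)

sd = {"mid_block.attentions.0.query.weight": torch.randn(32, 32)}
rename_deprecated_attention_keys(sd, "mid_block.attentions.0")
assert "mid_block.attentions.0.to_q.weight" in sd
```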

View File

@@ -18,7 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from .attention import AdaGroupNorm, AttentionBlock
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
@@ -427,12 +427,17 @@ class UNetMidBlock2D(nn.Module):
for _ in range(num_layers):
if self.add_attention:
attentions.append(
AttentionBlock(
Attention(
in_channels,
num_head_channels=attn_num_head_channels,
heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
@@ -711,12 +716,17 @@ class AttnDownBlock2D(nn.Module):
)
)
attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -1060,12 +1070,17 @@ class AttnDownEncoderBlock2D(nn.Module):
)
)
attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -1134,11 +1149,17 @@ class AttnSkipDownBlock2D(nn.Module):
)
)
self.attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=32,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -1703,12 +1724,17 @@ class AttnUpBlock2D(nn.Module):
)
)
attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -2037,12 +2063,17 @@ class AttnUpDecoderBlock2D(nn.Module):
)
)
attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -2109,11 +2140,17 @@ class AttnSkipUpBlock2D(nn.Module):
)
self.attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=32,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
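Every block above repeats the same translation from the old `num_head_channels` argument to the new `heads`/`dim_head` pair. Condensed here as a sketch (the helper name is hypothetical, not part of the commit):

```python
def attention_head_args(channels, attn_num_head_channels=None):
    """Translate the old num_head_channels convention into heads/dim_head."""
    if attn_num_head_channels is None:
        # old default: one head spanning all channels
        return {"heads": 1, "dim_head": channels}
    return {
        "heads": channels // attn_num_head_channels,
        "dim_head": attn_num_head_channels,
    }

print(attention_head_args(512))      # {'heads': 1, 'dim_head': 512}
print(attention_head_args(512, 64))  # {'heads': 8, 'dim_head': 64}
```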

View File

@@ -19,11 +19,11 @@ from typing import Any, Callable, List, Optional, Union
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -709,12 +709,14 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
# make sure the VAE is in float32 mode, as it overflows in float16
self.vae.to(dtype=torch.float32)
# TODO(Patrick, William) - clean up when attention is refactored
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
]
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if not use_torch_2_0_attn and not use_xformers:
if not use_torch_2_0_or_xformers:
self.vae.post_quant_conv.to(latents.dtype)
self.vae.decoder.conv_in.to(latents.dtype)
self.vae.decoder.mid_block.to(latents.dtype)
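For context, the pipeline above now reads the processor object off the attention module rather than the old private xformers flag. A brief sketch of inspecting that attribute on a standalone `Attention` (the `isinstance` check is an illustration, not the exact expression used in the pipeline):

```python
# Sketch: the processor attribute identifies the active attention backend.
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0

attn = Attention(query_dim=32, heads=1, dim_head=32)

# diffusers defaults to AttnProcessor2_0 when F.scaled_dot_product_attention is available
expected = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
assert isinstance(attn.processor, expected)
```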

View File

@@ -20,7 +20,7 @@ import numpy as np
import torch
from torch import nn
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
@@ -314,59 +314,6 @@ class ResnetBlock2DTests(unittest.TestCase):
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class AttentionBlockTests(unittest.TestCase):
@unittest.skipIf(
torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039"
)
def test_attention_block_default(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
attentionBlock = AttentionBlock(
channels=32,
num_head_channels=1,
rescale_output_factor=1.0,
eps=1e-6,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
assert attention_scores.shape == (1, 32, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_attention_block_sd(self):
# This version uses SD params and is compatible with mps
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 512, 64, 64).to(torch_device)
attentionBlock = AttentionBlock(
channels=512,
rescale_output_factor=1.0,
eps=1e-6,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
assert attention_scores.shape == (1, 512, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[-0.6621, -0.0156, -3.2766, 0.8025, -0.8609, 0.2820, 0.0905, -1.1179, -3.2126], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_default(self):
torch.manual_seed(0)