
Add support for XFormers in SD3 (#8583)

* Add support for XFormers in SD3

* sd3 xformers test

* sd3 xformers quality

* sd3 xformers update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Authored by: Canva
Authored on: 2024-12-12 14:35:39 +08:00
Committed by: GitHub
Parent: 26e80e0143
Commit: 7db9463e52

2 changed files with 124 additions and 0 deletions
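As a rough usage sketch of the feature this commit enables (not part of the diff itself), xFormers memory-efficient attention can be turned on for an SD3 pipeline in the usual way; the checkpoint id, dtype, and prompt below are illustrative assumptions:

# Hedged sketch: exercising the new xFormers path on SD3.
# The checkpoint id, dtype, and prompt are illustrative assumptions, not part of this commit.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# After this call, the SD3 transformer's joint attention layers use the
# XFormersJointAttnProcessor added in the diff below instead of JointAttnProcessor2_0.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]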

View File

@@ -358,6 +358,14 @@ class Attention(nn.Module):
            self.processor,
            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
        )
        is_joint_processor = hasattr(self, "processor") and isinstance(
            self.processor,
            (
                JointAttnProcessor2_0,
                XFormersJointAttnProcessor,
            ),
        )

        if use_memory_efficient_attention_xformers:
            if is_added_kv_processor and is_custom_diffusion:
                raise NotImplementedError(
@@ -420,6 +428,8 @@ class Attention(nn.Module):
                    processor.to(
                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
                    )
            elif is_joint_processor:
                processor = XFormersJointAttnProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
@@ -1685,6 +1695,91 @@ class FusedJointAttnProcessor2_0:
        return hidden_states, encoder_hidden_states


class XFormersJointAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if encoder_hidden_states is not None:
            # Split the attention outputs.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : residual.shape[1]],
                hidden_states[:, residual.shape[1] :],
            )
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class AllegroAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
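The docstring above recommends leaving attention_op as None so that xFormers picks the best kernel, but a specific operator can also be pinned. A minimal sketch, assuming `model` is an already-loaded SD3Transformer2DModel on a machine with CUDA and xformers available, and that every attention layer in it accepts a joint processor:

# Hedged sketch: pinning a specific xFormers operator instead of auto-selecting one.
# `model` is assumed to be an already-loaded SD3Transformer2DModel; CUDA and xformers required.
import xformers.ops

from diffusers.models.attention_processor import XFormersJointAttnProcessor

model.enable_xformers_memory_efficient_attention(
    attention_op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp
)

# Manual alternative with the default operator selection (assumes all attention
# layers in the model take a joint processor):
model.set_attn_processor(XFormersJointAttnProcessor(attention_op=None))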

View File

@@ -18,6 +18,7 @@ import unittest
import torch

from diffusers import SD3Transformer2DModel
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    torch_device,
@@ -80,6 +81,20 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
        ), "xformers is not enabled"

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_set_attn_processor_for_determinism(self):
        pass
@@ -140,6 +155,20 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
        ), "xformers is not enabled"

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_set_attn_processor_for_determinism(self):
        pass
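A standalone sketch of what test_xformers_enable_works asserts, runnable outside the test harness. The tiny model config below is an illustrative assumption, not the values produced by prepare_init_args_and_inputs_for_common(); like the tests, it requires torch.cuda.is_available() and xformers installed:

# Hedged sketch mirroring test_xformers_enable_works; config values are illustrative.
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=1,
    in_channels=4,
    num_layers=1,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,  # must equal num_attention_heads * attention_head_dim
    pooled_projection_dim=64,
    out_channels=4,
)

# Swaps JointAttnProcessor2_0 for the new XFormersJointAttnProcessor on every block.
model.enable_xformers_memory_efficient_attention()
assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"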