Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
Add support for XFormers in SD3 (#8583)
* Add support for XFormers in SD3
* sd3 xformers test
* sd3 xformers quality
* sd3 xformers update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
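For context, a minimal usage sketch of what this change enables (not part of the diff; it assumes a CUDA GPU with the `xformers` package installed, and the checkpoint id and prompt are illustrative):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# With this commit, the call below routes the SD3 transformer's joint attention
# blocks to the new XFormersJointAttnProcessor.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse").images[0]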
@@ -358,6 +358,14 @@ class Attention(nn.Module):
            self.processor,
            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
        )
        is_joint_processor = hasattr(self, "processor") and isinstance(
            self.processor,
            (
                JointAttnProcessor2_0,
                XFormersJointAttnProcessor,
            ),
        )

        if use_memory_efficient_attention_xformers:
            if is_added_kv_processor and is_custom_diffusion:
                raise NotImplementedError(
@@ -420,6 +428,8 @@ class Attention(nn.Module):
                    processor.to(
                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
                    )
            elif is_joint_processor:
                processor = XFormersJointAttnProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
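A quick way to confirm the dispatch above took effect (a sketch reusing `pipe` from the example near the top; `attn_processors` is the standard diffusers accessor on the transformer):

# Each SD3 joint attention block should now report XFormersJointAttnProcessor.
for name, proc in pipe.transformer.attn_processors.items():
    print(name, type(proc).__name__)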
@@ -1685,6 +1695,91 @@ class FusedJointAttnProcessor2_0:
        return hidden_states, encoder_hidden_states


class XFormersJointAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if encoder_hidden_states is not None:
            # Split the attention outputs.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : residual.shape[1]],
                hidden_states[:, residual.shape[1] :],
            )
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class AllegroAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
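The new processor can also be installed directly via `set_attn_processor`, without going through `enable_xformers_memory_efficient_attention()`; a minimal sketch (reusing `pipe` from the example above, and passing `attention_op=None` so xFormers picks the kernel, as the docstring recommends):

from diffusers.models.attention_processor import XFormersJointAttnProcessor

# Explicitly install the joint xformers processor on every attention module
# of the SD3 transformer.
pipe.transformer.set_attn_processor(XFormersJointAttnProcessor(attention_op=None))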
@@ -18,6 +18,7 @@ import unittest
import torch

from diffusers import SD3Transformer2DModel
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    torch_device,
@@ -80,6 +81,20 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
        ), "xformers is not enabled"

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_set_attn_processor_for_determinism(self):
        pass
@@ -140,6 +155,20 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
        ), "xformers is not enabled"

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_set_attn_processor_for_determinism(self):
        pass