diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 08263875d0..b5acd6f4f9 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
 import torch.nn.functional as F
@@ -72,6 +72,7 @@ class AttentionBlock(nn.Module):
         self.proj_attn = nn.Linear(channels, channels, 1)
 
         self._use_memory_efficient_attention_xformers = False
+        self._attention_op = None
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -87,7 +88,9 @@ class AttentionBlock(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if use_memory_efficient_attention_xformers:
             if not is_xformers_available():
                 raise ModuleNotFoundError(
@@ -113,6 +116,7 @@ class AttentionBlock(nn.Module):
                 except Exception as e:
                     raise e
         self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self._attention_op = attention_op
 
     def forward(self, hidden_states):
         residual = hidden_states
@@ -136,7 +140,9 @@ class AttentionBlock(nn.Module):
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+            )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
             attention_scores = torch.baddbmm(
diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index d4da50c23f..7dda30fbda 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -93,7 +93,9 @@ class CrossAttention(nn.Module):
         processor = processor if processor is not None else CrossAttnProcessor()
         self.set_processor(processor)
 
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -127,7 +129,7 @@ class CrossAttention(nn.Module):
                 except Exception as e:
                     raise e
 
-            processor = XFormersCrossAttnProcessor()
+            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
             processor = CrossAttnProcessor()
 
@@ -351,6 +353,9 @@ class CrossAttnAddedKVProcessor:
 
 
 class XFormersCrossAttnProcessor:
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
 
@@ -366,7 +371,9 @@ class XFormersCrossAttnProcessor:
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
 
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index afe5689fdb..ceddf70b2c 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -190,13 +190,15 @@ class ModelMixin(torch.nn.Module):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
         def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
 
             for child in module.children():
                 fn_recursive_set_mem_eff(child)
@@ -205,7 +207,7 @@
             if isinstance(module, torch.nn.Module):
                 fn_recursive_set_mem_eff(module)
 
-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.
 
@@ -214,8 +216,28 @@
 
         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.
+
+        Parameters:
+            attention_op (`Callable`, *optional*):
+                Override the default `None` operator for use as `op` argument to the
+                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+                function of xFormers.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import UNet2DConditionModel
+        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+        >>> model = UNet2DConditionModel.from_pretrained(
+        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
+        ... )
+        >>> model = model.to("cuda")
+        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+        ```
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)
 
     def disable_xformers_memory_efficient_attention(self):
         r"""
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 14f0454f6d..1c7d2c41a9 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -19,7 +19,7 @@ import inspect
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -842,7 +842,7 @@ class DiffusionPipeline(ConfigMixin):
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
 
-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.
 
@@ -851,8 +851,28 @@
 
         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.
+
+        Parameters:
+            attention_op (`Callable`, *optional*):
+                Override the default `None` operator for use as `op` argument to the
+                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+                function of xFormers.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import DiffusionPipeline
+        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+        >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
+        >>> pipe = pipe.to("cuda")
+        >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+        >>> # Workaround for not accepting attention shape using VAE for Flash Attention
+        >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
+        ```
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)
 
     def disable_xformers_memory_efficient_attention(self):
         r"""
@@ -860,13 +880,15 @@
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
         def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
 
             for child in module.children():
                 fn_recursive_set_mem_eff(child)
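A minimal end-to-end sketch of the new `attention_op` argument added above, assuming xformers and a CUDA device are available and the `stabilityai/stable-diffusion-2-1` checkpoint can be fetched; the final generation call is illustrative and not part of the diff:

```py
import torch

from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Route every patched attention module through the Flash-Attention operator ...
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# ... except the VAE, whose attention shapes Flash-Attention does not accept;
# attention_op=None keeps xformers' automatic operator selection, i.e. the old behaviour.
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```

Because `attention_op` defaults to `None` and is simply forwarded as `op=...` to `xformers.ops.memory_efficient_attention`, callers that do not pass the argument keep the previous behaviour.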
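The operator can also be injected below the pipeline helpers. Both paths in this sketch rely only on methods touched by the diff plus the existing `CrossAttention.set_processor`; the checkpoint id and the `mid_block.attentions[0].transformer_blocks[0].attn1` attribute path are illustrative assumptions about the loaded UNet:

```py
import torch

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import XFormersCrossAttnProcessor
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# Path 1: ModelMixin.set_use_memory_efficient_attention_xformers now forwards the
# operator to every child module that implements the same method.
unet.set_use_memory_efficient_attention_xformers(True, MemoryEfficientAttentionFlashAttentionOp)

# Path 2: construct XFormersCrossAttnProcessor directly and attach it to a single
# CrossAttention module, bypassing the recursive walk.
processor = XFormersCrossAttnProcessor(attention_op=MemoryEfficientAttentionFlashAttentionOp)
unet.mid_block.attentions[0].transformer_blocks[0].attn1.set_processor(processor)
```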