Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[docs] Add AttnProcessor to docs (#3474)
* add attnprocessor to docs
* fix path to class
* create separate page for attnprocessors
* fix path
* fix path for real
* fill in docstrings
* apply feedback
* apply feedback
@@ -132,6 +132,8 @@
 - sections:
   - local: api/models
     title: Models
+  - local: api/attnprocessor
+    title: Attention Processor
   - local: api/diffusion_pipeline
     title: Diffusion Pipeline
   - local: api/logging
docs/source/en/api/attnprocessor.mdx (new file, 39 lines)
@@ -0,0 +1,39 @@
# Attention Processor

An attention processor is a class for applying different types of attention mechanisms.

## AttnProcessor
[[autodoc]] models.attention_processor.AttnProcessor

## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0

## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor

## CustomDiffusionAttnProcessor
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor

## AttnAddedKVProcessor
[[autodoc]] models.attention_processor.AttnAddedKVProcessor

## AttnAddedKVProcessor2_0
[[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0

## LoRAAttnAddedKVProcessor
[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor

## XFormersAttnProcessor
[[autodoc]] models.attention_processor.XFormersAttnProcessor

## LoRAXFormersAttnProcessor
[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor

## CustomDiffusionXFormersAttnProcessor
[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor

## SlicedAttnProcessor
[[autodoc]] models.attention_processor.SlicedAttnProcessor

## SlicedAttnAddedKVProcessor
[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
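To ground the new page: processors are swapped onto a model at runtime rather than configured statically. A minimal sketch of how that looks (illustrative only, not part of the committed file; the checkpoint name is just an example), assuming `UNet2DConditionModel.set_attn_processor` as exposed by diffusers:

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

# Load a UNet and replace the attention processor on every attention module.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # example checkpoint
)
unet.set_attn_processor(AttnProcessor2_0())  # use PyTorch 2.0 SDPA everywhere
unet.set_default_attn_processor()            # or restore the default AttnProcessor
```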
@@ -431,6 +431,10 @@ class Attention(nn.Module):


 class AttnProcessor:
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
     def __call__(
         self,
         attn: Attention,
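For orientation, the default processor boils down to standard softmax attention over the projected hidden states. A self-contained sketch of that computation (illustrative, not code from the diff):

```python
import torch

# softmax(Q K^T / sqrt(d)) V over (batch*heads, seq_len, head_dim) tensors,
# which is the core computation AttnProcessor performs after projecting Q/K/V.
def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    scale = q.shape[-1] ** -0.5
    scores = torch.bmm(q, k.transpose(-1, -2)) * scale  # (B*H, q_len, k_len)
    probs = scores.softmax(dim=-1)
    return torch.bmm(probs, v)                          # (B*H, q_len, head_dim)

q = k = v = torch.randn(2, 16, 64)
assert naive_attention(q, k, v).shape == (2, 16, 64)
```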
@@ -516,6 +520,18 @@ class LoRALinearLayer(nn.Module):


 class LoRAAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()

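To make `rank` concrete: under the usual LoRA formulation, each attention projection `W x` gets a low-rank correction `up(down(x))` whose inner dimension is `rank`. A hypothetical standalone sketch (the class name is illustrative, not one added in this diff):

```python
import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    """Frozen base projection plus a trainable rank-`rank` update."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.requires_grad_(False)                        # pretrained weight stays frozen
        self.down = nn.Linear(in_features, rank, bias=False)   # trainable (rank x in)
        self.up = nn.Linear(rank, out_features, bias=False)    # trainable (out x rank)
        nn.init.zeros_(self.up.weight)                         # update starts as a no-op

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return self.base(x) + scale * self.up(self.down(x))

y = TinyLoRALinear(320, 320, rank=4)(torch.randn(1, 77, 320))
```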
@@ -580,6 +596,24 @@ class LoRAAttnProcessor(nn.Module):


 class CustomDiffusionAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
     def __init__(
         self,
         train_kv=True,
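As a rough picture of what `train_kv` and `train_q_out` control: Custom Diffusion fine-tunes freshly initialized cross-attention key/value (and optionally query/output) projections against the text features while the rest of the network stays frozen. A hypothetical sketch, assuming the attention module exposes `to_k`/`to_v` linear projections as diffusers' `Attention` does:

```python
import torch.nn as nn

def add_custom_kv(attn: nn.Module, cross_attention_dim: int, hidden_size: int):
    """Clone the cross-attention K/V projections into new trainable layers (the gist of train_kv=True)."""
    attn.requires_grad_(False)  # freeze the pretrained attention weights
    to_k_custom = nn.Linear(cross_attention_dim, hidden_size, bias=False)
    to_v_custom = nn.Linear(cross_attention_dim, hidden_size, bias=False)
    to_k_custom.load_state_dict(attn.to_k.state_dict())  # start from the pretrained weights
    to_v_custom.load_state_dict(attn.to_v.state_dict())
    return to_k_custom, to_v_custom  # only these are handed to the optimizer
```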
@@ -658,6 +692,11 @@ class CustomDiffusionAttnProcessor(nn.Module):


 class AttnAddedKVProcessor:
+    r"""
+    Processor for performing attention-related computations with extra learnable key and value matrices for the text
+    encoder.
+    """
+
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
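To illustrate the "extra learnable key and value matrices": the text-encoder states are run through additional projections, and the resulting keys/values are concatenated with those computed from the image features before attention is applied. A conceptual sketch only (shapes are examples, not the processor's actual code):

```python
import torch

def concat_added_kv(k_img, v_img, k_txt, v_txt):
    # Keys/values from the extra text projections are joined with the image K/V.
    k = torch.cat([k_txt, k_img], dim=1)  # (batch, txt_len + img_len, dim)
    v = torch.cat([v_txt, v_img], dim=1)
    return k, v

k_img, v_img = torch.randn(1, 64, 320), torch.randn(1, 64, 320)   # image-feature K/V
k_txt, v_txt = torch.randn(1, 77, 320), torch.randn(1, 77, 320)   # projected text K/V
k, v = concat_added_kv(k_img, v_img, k_txt, v_txt)
```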
@@ -707,6 +746,11 @@ class AttnAddedKVProcessor:


 class AttnAddedKVProcessor2_0:
+    r"""
+    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+    learnable key and value matrices for the text encoder.
+    """
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -765,6 +809,19 @@ class AttnAddedKVProcessor2_0:


 class LoRAAttnAddedKVProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
+    encoder.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()

@@ -832,6 +889,17 @@ class LoRAAttnAddedKVProcessor(nn.Module):


 class XFormersAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op

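In practice this processor is usually enabled through the pipeline helper rather than constructed by hand. A short sketch (requires the `xformers` package and a CUDA device; the checkpoint name is only an example):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # example checkpoint
).to("cuda")

# Swaps every attention module over to xFormers memory-efficient attention;
# leaving attention_op unset lets xFormers pick the best operator.
pipe.enable_xformers_memory_efficient_attention()
image = pipe("an astronaut riding a horse").images[0]
```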
@@ -905,6 +973,10 @@ class XFormersAttnProcessor:


 class AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
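The PyTorch 2.0 primitive this processor dispatches to can be exercised directly, which also shows what the `ImportError` guard checks for. A minimal sketch:

```python
import torch
import torch.nn.functional as F

# AttnProcessor2_0 is only usable when this fused kernel exists (PyTorch >= 2.0).
assert hasattr(F, "scaled_dot_product_attention")

# (batch, heads, seq_len, head_dim) layout expected by the fused kernel.
q = torch.randn(1, 8, 77, 64)
k = torch.randn(1, 8, 77, 64)
v = torch.randn(1, 8, 77, 64)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 8, 77, 64])
```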
@@ -983,6 +1055,23 @@ class AttnProcessor2_0:


 class LoRAXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()

@@ -1049,6 +1138,28 @@ class LoRAXFormersAttnProcessor(nn.Module):


 class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
+            as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
+    """
+
     def __init__(
         self,
         train_kv=True,
@@ -1134,6 +1245,15 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):


 class SlicedAttnProcessor:
+    r"""
+    Processor for implementing sliced attention.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of the `slice_size`.
+    """
+
     def __init__(self, slice_size):
         self.slice_size = slice_size

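Sliced attention is normally switched on through the pipeline helper, which picks the processor and `slice_size` for you; computing attention one slice at a time trades speed for a smaller peak memory footprint. A short sketch (the checkpoint name is only an example):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # example checkpoint

# Compute attention in attention_head_dim // slice_size chunks instead of all at once.
pipe.enable_attention_slicing(slice_size=1)   # smallest memory footprint, slowest
# pipe.enable_attention_slicing("auto")       # default: halve the number of heads per slice
image = pipe("a photo of a cat", num_inference_steps=20).images[0]
```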
@@ -1206,6 +1326,15 @@ class SlicedAttnProcessor:


 class SlicedAttnAddedKVProcessor:
+    r"""
+    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of the `slice_size`.
+    """
+
     def __init__(self, slice_size):
         self.slice_size = slice_size
