mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] feat: add lora attention processor for pt 2.0. (#3594)
* feat: add lora attention processor for pt 2.0. * explicit context manager for SDPA. * switch to flash attention * make shapes compatible to work optimally with SDPA. * fix: circular import problem. * explicitly specify the flash attention kernel in sdpa * fall back to efficient attention context manager. * remove explicit dispatch. * fix: removed processor. * fix: remove optional from type annotation. * feat: make changes regarding LoRAAttnProcessor2_0. * remove confusing warning. * formatting. * relax tolerance for PT 2.0 * fix: loading message. * remove unnecessary logging. * add: entry to the docs. * add: network_alpha argument. * relax tolerance.
This commit is contained in:
@@ -11,6 +11,9 @@ An attention processor is a class for applying different types of attention mech
|
||||
## LoRAAttnProcessor
|
||||
[[autodoc]] models.attention_processor.LoRAAttnProcessor
|
||||
|
||||
## LoRAAttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
|
||||
|
||||
## CustomDiffusionAttnProcessor
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ from diffusers.models.attention_processor import (
|
||||
AttnAddedKVProcessor2_0,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
@@ -844,8 +845,9 @@ def main(args):
|
||||
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
|
||||
lora_attn_processor_class = LoRAAttnAddedKVProcessor
|
||||
else:
|
||||
lora_attn_processor_class = LoRAAttnProcessor
|
||||
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
unet_lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .models.attention_processor import (
|
||||
@@ -27,6 +28,7 @@ from .models.attention_processor import (
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
XFormersAttnProcessor,
|
||||
@@ -287,7 +289,9 @@ class UNet2DConditionLoadersMixin:
|
||||
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
|
||||
attn_processor_class = LoRAXFormersAttnProcessor
|
||||
else:
|
||||
attn_processor_class = LoRAAttnProcessor
|
||||
attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
|
||||
attn_processors[key] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
@@ -927,11 +931,11 @@ class LoraLoaderMixin:
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
|
||||
logger.info(f"Loading {self.text_encoder_name}.")
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {self.text_encoder_name}.")
|
||||
attn_procs_text_encoder = self._load_text_encoder_attn_procs(
|
||||
text_encoder_lora_state_dict, network_alpha=network_alpha
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -166,7 +165,8 @@ class Attention(nn.Module):
|
||||
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
||||
):
|
||||
is_lora = hasattr(self, "processor") and isinstance(
|
||||
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
|
||||
self.processor,
|
||||
(LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
|
||||
)
|
||||
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
||||
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
||||
@@ -200,14 +200,6 @@ class Attention(nn.Module):
|
||||
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
||||
" only available for GPU "
|
||||
)
|
||||
elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
|
||||
warnings.warn(
|
||||
"You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
|
||||
"We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
|
||||
"introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
|
||||
"back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
|
||||
"native efficient flash attention."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Make sure we can run the memory efficient attention
|
||||
@@ -220,6 +212,8 @@ class Attention(nn.Module):
|
||||
raise e
|
||||
|
||||
if is_lora:
|
||||
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
||||
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
||||
processor = LoRAXFormersAttnProcessor(
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
@@ -252,7 +246,10 @@ class Attention(nn.Module):
|
||||
processor = XFormersAttnProcessor(attention_op=attention_op)
|
||||
else:
|
||||
if is_lora:
|
||||
processor = LoRAAttnProcessor(
|
||||
attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
processor = attn_processor_class(
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
rank=self.processor.rank,
|
||||
@@ -548,6 +545,8 @@ class LoRAAttnProcessor(nn.Module):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
rank (`int`, defaults to 4):
|
||||
The dimension of the LoRA update matrices.
|
||||
network_alpha (`int`, *optional*):
|
||||
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
||||
@@ -843,6 +842,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
rank (`int`, defaults to 4):
|
||||
The dimension of the LoRA update matrices.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
||||
@@ -1162,6 +1162,9 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
||||
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
||||
operator.
|
||||
network_alpha (`int`, *optional*):
|
||||
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -1236,6 +1239,97 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LoRAAttnProcessor2_0(nn.Module):
|
||||
r"""
|
||||
Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
|
||||
attention.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
rank (`int`, defaults to 4):
|
||||
The dimension of the LoRA update matrices.
|
||||
network_alpha (`int`, *optional*):
|
||||
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
||||
super().__init__()
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.rank = rank
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
inner_dim = hidden_states.shape[-1]
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
||||
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
||||
r"""
|
||||
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
|
||||
@@ -1520,6 +1614,7 @@ AttentionProcessor = Union[
|
||||
XFormersAttnAddedKVProcessor,
|
||||
LoRAAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
|
||||
@@ -19,6 +19,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
@@ -28,6 +29,7 @@ from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
@@ -46,16 +48,24 @@ def create_unet_lora_layers(unet: nn.Module):
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
|
||||
return lora_attn_procs, unet_lora_layers
|
||||
|
||||
|
||||
def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
|
||||
text_lora_attn_procs = {}
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
for name, module in text_encoder.named_modules():
|
||||
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
|
||||
text_lora_attn_procs[name] = LoRAAttnProcessor(
|
||||
text_lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=module.out_proj.out_features, cross_attention_dim=None
|
||||
)
|
||||
return text_lora_attn_procs
|
||||
@@ -368,7 +378,10 @@ class LoraLoaderMixinTests(unittest.TestCase):
|
||||
# check if lora attention processors are used
|
||||
for _, module in sd_pipe.unet.named_modules():
|
||||
if isinstance(module, Attention):
|
||||
self.assertIsInstance(module.processor, LoRAAttnProcessor)
|
||||
attn_proc_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
self.assertIsInstance(module.processor, attn_proc_class)
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
def test_lora_unet_attn_processors_with_xformers(self):
|
||||
|
||||
@@ -261,7 +261,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
|
||||
|
||||
assert (sample - new_sample).abs().max() < 1e-4
|
||||
assert (sample - new_sample).abs().max() < 5e-4
|
||||
|
||||
# LoRA and no LoRA should NOT be the same
|
||||
assert (sample - old_sample).abs().max() > 1e-4
|
||||
@@ -295,7 +295,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
|
||||
|
||||
assert (sample - new_sample).abs().max() < 1e-4
|
||||
assert (sample - new_sample).abs().max() < 3e-4
|
||||
|
||||
# LoRA and no LoRA should NOT be the same
|
||||
assert (sample - old_sample).abs().max() > 1e-4
|
||||
|
||||
Reference in New Issue
Block a user