
[Wuerstchen] fix for when USE_PEFT_BACKEND is True (#5704)

* fix for when USE_PEFT_BACKEND is True

* Update modeling_wuerstchen_prior.py

* revert change

* add lora tests
Author: Kashif Rasul
Date: 2023-11-13 18:37:55 +01:00 (committed by GitHub)
Parent: 8789d0b6c7
Commit: c9f847a70f
3 changed files with 99 additions and 10 deletions

src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py

@@ -17,6 +17,8 @@ import torch
 import torch.nn as nn
 
 from ...models.attention_processor import Attention
+from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from ...utils import USE_PEFT_BACKEND
 
 
 class WuerstchenLayerNorm(nn.LayerNorm):
@@ -32,7 +34,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
 class TimestepBlock(nn.Module):
     def __init__(self, c, c_timestep):
         super().__init__()
-        self.mapper = nn.Linear(c_timestep, c * 2)
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.mapper = linear_cls(c_timestep, c * 2)
 
     def forward(self, x, t):
         a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -42,10 +45,14 @@ class TimestepBlock(nn.Module):
 class ResBlock(nn.Module):
     def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         super().__init__()
-        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+        self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
+            linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
         )
 
     def forward(self, x, x_skip=None):
@@ -73,10 +80,13 @@ class GlobalResponseNorm(nn.Module):
 class AttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
+
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
         self.self_attn = self_attn
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
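
All four hunks above apply the same rule: when the PEFT backend is active, LoRA is injected by peft itself, so the blocks keep plain torch layers; otherwise diffusers' LoRACompatible* wrappers are used so the library's own LoRA loading can hook in. Below is a minimal, standalone sketch of that pattern on a TimestepBlock-like module; TinyTimestepBlock and the tiny sizes are illustrative only, and it assumes a diffusers checkout from around this commit, where LoRACompatibleLinear still exists.

import torch
import torch.nn as nn

from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils import USE_PEFT_BACKEND


class TinyTimestepBlock(nn.Module):
    # Illustrative analogue of TimestepBlock above, not library code.
    def __init__(self, c, c_timestep):
        super().__init__()
        # Plain nn.Linear when peft handles LoRA injection, diffusers' wrapper otherwise.
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
        self.mapper = linear_cls(c_timestep, c * 2)

    def forward(self, x, t):
        # Split the mapped timestep embedding into a scale and a shift, as in the block above.
        a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
        return x * (1 + a) + b


block = TinyTimestepBlock(c=8, c_timestep=4)
out = block(torch.randn(1, 8, 2, 2), torch.randn(1, 4))
print(type(block.mapper).__name__, out.shape)  # "Linear" if peft is installed

Either class exposes the same forward signature, so the surrounding model code is unchanged; only the LoRA wiring differs.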

src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py

@@ -28,8 +28,9 @@ from ...models.attention_processor import (
     AttnAddedKVProcessor,
     AttnProcessor,
 )
+from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
 from ...models.modeling_utils import ModelMixin
-from ...utils import is_torch_version
+from ...utils import USE_PEFT_BACKEND, is_torch_version
 from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
@@ -40,12 +41,15 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
         self.c_r = c_r
-        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
+        self.projection = conv_cls(c_in, c, kernel_size=1)
         self.cond_mapper = nn.Sequential(
-            nn.Linear(c_cond, c),
+            linear_cls(c_cond, c),
             nn.LeakyReLU(0.2),
-            nn.Linear(c, c),
+            linear_cls(c, c),
         )
 
         self.blocks = nn.ModuleList()
@@ -55,7 +59,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
         self.out = nn.Sequential(
             WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
-            nn.Conv2d(c, c_in * 2, kernel_size=1),
+            conv_cls(c, c_in * 2, kernel_size=1),
         )
 
         self.gradient_checkpointing = False
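
A quick way to confirm which classes the prior ended up with, again only as a sketch with a deliberately tiny, non-default config chosen for illustration; note that LoRACompatibleConv subclasses nn.Conv2d, so the exact type is compared rather than using isinstance:

import torch.nn as nn

from diffusers.models.lora import LoRACompatibleConv
from diffusers.pipelines.wuerstchen import WuerstchenPrior
from diffusers.utils import USE_PEFT_BACKEND

# Tiny config purely for inspection; the defaults (c=1280, depth=16) are much larger.
prior = WuerstchenPrior(c_in=2, c=8, c_cond=8, c_r=8, depth=1, nhead=2)

# With peft installed the projection is a plain nn.Conv2d; without it,
# diffusers' LoRA-aware wrapper is used instead.
expected_conv = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
print(type(prior.projection) is expected_conv)  # True in both cases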

tests/pipelines/wuerstchen/test_wuerstchen_prior.py

@@ -17,11 +17,24 @@ import unittest
 
 import numpy as np
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import (
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+)
 from diffusers.pipelines.wuerstchen import WuerstchenPrior
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+from diffusers.utils.import_utils import is_peft_available
+from diffusers.utils.testing_utils import enable_full_determinism, require_peft_backend, skip_mps, torch_device
+
+
+if is_peft_available():
+    from peft import LoraConfig
+    from peft.tuners.tuners_utils import BaseTunerLayer
 
 
 from ..test_pipelines_common import PipelineTesterMixin
@@ -29,6 +42,19 @@ from ..test_pipelines_common import PipelineTesterMixin
 enable_full_determinism()
 
 
+def create_prior_lora_layers(unet: nn.Module):
+    lora_attn_procs = {}
+    for name in unet.attn_processors.keys():
+        lora_attn_processor_class = (
+            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+        )
+        lora_attn_procs[name] = lora_attn_processor_class(
+            hidden_size=unet.config.c,
+        )
+    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
+    return lora_attn_procs, unet_lora_layers
+
+
 class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = WuerstchenPriorPipeline
     params = ["prompt"]
@@ -219,3 +245,52 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         output = pipe(**inputs)[0]
 
         assert output.abs().sum() == 0
+
+    def check_if_lora_correctly_set(self, model) -> bool:
+        """
+        Checks if the LoRA layers are correctly set with peft
+        """
+        for module in model.modules():
+            if isinstance(module, BaseTunerLayer):
+                return True
+        return False
+
+    def get_lora_components(self):
+        prior = self.dummy_prior
+
+        prior_lora_config = LoraConfig(
+            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+        )
+
+        prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
+
+        lora_components = {
+            "prior_lora_layers": prior_lora_layers,
+            "prior_lora_attn_procs": prior_lora_attn_procs,
+        }
+
+        return prior, prior_lora_config, lora_components
+
+    @require_peft_backend
+    def test_inference_with_prior_lora(self):
+        _, prior_lora_config, _ = self.get_lora_components()
+        device = "cpu"
+
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+
+        pipe.set_progress_bar_config(disable=None)
+
+        output_no_lora = pipe(**self.get_dummy_inputs(device))
+        image_embed = output_no_lora.image_embeddings
+        self.assertTrue(image_embed.shape == (1, 2, 24, 24))
+
+        pipe.prior.add_adapter(prior_lora_config)
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")
+
+        output_lora = pipe(**self.get_dummy_inputs(device))
+        lora_image_embed = output_lora.image_embeddings
+
+        self.assertTrue(image_embed.shape == lora_image_embed.shape)
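
The new test exercises this end to end through the pipeline; the essential flow can also be sketched directly on the prior module, assuming peft is installed and reusing the same illustrative tiny config as above:

from peft import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer

from diffusers.pipelines.wuerstchen import WuerstchenPrior

prior = WuerstchenPrior(c_in=2, c=8, c_cond=8, c_r=8, depth=1, nhead=2)

# Same adapter spec as get_lora_components: rank-4 LoRA on the attention
# projections inside each AttnBlock.
prior_lora_config = LoraConfig(
    r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
prior.add_adapter(prior_lora_config)

# Mirrors check_if_lora_correctly_set: peft swaps the targeted layers for
# tuner layers once the adapter is attached.
print(any(isinstance(m, BaseTunerLayer) for m in prior.modules()))  # True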