Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Wuerstchen] fix for when USE_PEFT_BACKEND is True (#5704)

* fix for when USE_PEFT_BACKEND is True
* Update modeling_wuerstchen_prior.py
* revert change
* add lora tests
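The change applies one pattern throughout: pick the layer class at construction time, because when the peft backend is active LoRA is injected into plain torch.nn layers, so the LoRACompatible* wrappers must not be used. Below is a minimal sketch of that pattern; the MyBlock module is a hypothetical stand-in for TimestepBlock / ResBlock / AttnBlock, not code from the repository.

# Minimal sketch (illustrative only) of the class-selection pattern this diff applies.
import torch.nn as nn

from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.utils import USE_PEFT_BACKEND


class MyBlock(nn.Module):
    # Hypothetical module standing in for TimestepBlock / ResBlock / AttnBlock.
    def __init__(self, c_in, c_out):
        super().__init__()
        # With the peft backend, LoRA is injected into plain nn layers,
        # so the LoRACompatible* wrappers are only used as a fallback.
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
        self.depthwise = conv_cls(c_in, c_out, kernel_size=3, padding=1)
        self.mapper = linear_cls(c_out, c_out * 2)

    def forward(self, x):
        h = self.depthwise(x).mean(dim=(-2, -1))  # (batch, c_out)
        a, b = self.mapper(h).chunk(2, dim=1)
        return a + b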
src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py

@@ -17,6 +17,8 @@ import torch
 import torch.nn as nn
 
 from ...models.attention_processor import Attention
+from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from ...utils import USE_PEFT_BACKEND
 
 
 class WuerstchenLayerNorm(nn.LayerNorm):
@@ -32,7 +34,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
 class TimestepBlock(nn.Module):
     def __init__(self, c, c_timestep):
         super().__init__()
-        self.mapper = nn.Linear(c_timestep, c * 2)
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.mapper = linear_cls(c_timestep, c * 2)
 
     def forward(self, x, t):
         a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -42,10 +45,14 @@ class TimestepBlock(nn.Module):
 class ResBlock(nn.Module):
     def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         super().__init__()
-        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+        self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
+            linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
         )
 
     def forward(self, x, x_skip=None):
@@ -73,10 +80,13 @@ class GlobalResponseNorm(nn.Module):
 class AttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
+
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
         self.self_attn = self_attn
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py

@@ -28,8 +28,9 @@ from ...models.attention_processor import (
     AttnAddedKVProcessor,
     AttnProcessor,
 )
+from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
 from ...models.modeling_utils import ModelMixin
-from ...utils import is_torch_version
+from ...utils import USE_PEFT_BACKEND, is_torch_version
 from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
 
 
@@ -40,12 +41,15 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
         self.c_r = c_r
-        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
+        self.projection = conv_cls(c_in, c, kernel_size=1)
         self.cond_mapper = nn.Sequential(
-            nn.Linear(c_cond, c),
+            linear_cls(c_cond, c),
             nn.LeakyReLU(0.2),
-            nn.Linear(c, c),
+            linear_cls(c, c),
         )
 
         self.blocks = nn.ModuleList()
@@ -55,7 +59,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
         self.out = nn.Sequential(
             WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
-            nn.Conv2d(c, c_in * 2, kernel_size=1),
+            conv_cls(c, c_in * 2, kernel_size=1),
         )
 
         self.gradient_checkpointing = False
tests/pipelines/wuerstchen/test_wuerstchen_prior.py

@@ -17,11 +17,24 @@ import unittest
 
 import numpy as np
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import (
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+)
 from diffusers.pipelines.wuerstchen import WuerstchenPrior
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+from diffusers.utils.import_utils import is_peft_available
+from diffusers.utils.testing_utils import enable_full_determinism, require_peft_backend, skip_mps, torch_device
+
+
+if is_peft_available():
+    from peft import LoraConfig
+    from peft.tuners.tuners_utils import BaseTunerLayer
 
 from ..test_pipelines_common import PipelineTesterMixin
 
@@ -29,6 +42,19 @@ from ..test_pipelines_common import PipelineTesterMixin
 enable_full_determinism()
 
 
+def create_prior_lora_layers(unet: nn.Module):
+    lora_attn_procs = {}
+    for name in unet.attn_processors.keys():
+        lora_attn_processor_class = (
+            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+        )
+        lora_attn_procs[name] = lora_attn_processor_class(
+            hidden_size=unet.config.c,
+        )
+    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
+    return lora_attn_procs, unet_lora_layers
+
+
 class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = WuerstchenPriorPipeline
     params = ["prompt"]
@@ -219,3 +245,52 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         output = pipe(**inputs)[0]
 
         assert output.abs().sum() == 0
+
+    def check_if_lora_correctly_set(self, model) -> bool:
+        """
+        Checks if the LoRA layers are correctly set with peft
+        """
+        for module in model.modules():
+            if isinstance(module, BaseTunerLayer):
+                return True
+        return False
+
+    def get_lora_components(self):
+        prior = self.dummy_prior
+
+        prior_lora_config = LoraConfig(
+            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+        )
+
+        prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
+
+        lora_components = {
+            "prior_lora_layers": prior_lora_layers,
+            "prior_lora_attn_procs": prior_lora_attn_procs,
+        }
+
+        return prior, prior_lora_config, lora_components
+
+    @require_peft_backend
+    def test_inference_with_prior_lora(self):
+        _, prior_lora_config, _ = self.get_lora_components()
+        device = "cpu"
+
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+
+        pipe.set_progress_bar_config(disable=None)
+
+        output_no_lora = pipe(**self.get_dummy_inputs(device))
+        image_embed = output_no_lora.image_embeddings
+        self.assertTrue(image_embed.shape == (1, 2, 24, 24))
+
+        pipe.prior.add_adapter(prior_lora_config)
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")
+
+        output_lora = pipe(**self.get_dummy_inputs(device))
+        lora_image_embed = output_lora.image_embeddings
+
+        self.assertTrue(image_embed.shape == lora_image_embed.shape)
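For context, the flow the new test exercises roughly corresponds to the usage sketch below. It is not part of the diff: the checkpoint id and prompt are placeholders, and peft must be installed so that USE_PEFT_BACKEND is True.

# Rough usage sketch of what test_inference_with_prior_lora checks; assumes peft
# is installed and uses "warp-ai/wuerstchen-prior" as a placeholder checkpoint.
from peft import LoraConfig

from diffusers import WuerstchenPriorPipeline

pipe = WuerstchenPriorPipeline.from_pretrained("warp-ai/wuerstchen-prior")

lora_config = LoraConfig(
    r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
# With the fix, add_adapter works because the prior now builds plain nn.Linear /
# nn.Conv2d layers whenever USE_PEFT_BACKEND is True.
pipe.prior.add_adapter(lora_config)

image_embeddings = pipe("a placeholder prompt", num_inference_steps=2).image_embeddings
print(image_embeddings.shape)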