From c9f847a70f7d44f6a856fd61d4fa03dbbab72fdc Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 13 Nov 2023 18:37:55 +0100
Subject: [PATCH] [Wuerstchen] fix for when USE_PEFT_BACKEND is True (#5704)

* fix for when USE_PEFT_BACKEND is True

* Update modeling_wuerstchen_prior.py

* revert change

* add lora tests
---
 .../wuerstchen/modeling_wuerstchen_common.py  | 18 ++++-
 .../wuerstchen/modeling_wuerstchen_prior.py   | 14 ++--
 .../wuerstchen/test_wuerstchen_prior.py       | 77 ++++++++++++++++++-
 3 files changed, 99 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
index b3aac39386..00d6f01bec 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
@@ -17,6 +17,8 @@ import torch
 import torch.nn as nn
 
 from ...models.attention_processor import Attention
+from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from ...utils import USE_PEFT_BACKEND
 
 
 class WuerstchenLayerNorm(nn.LayerNorm):
@@ -32,7 +34,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
 class TimestepBlock(nn.Module):
     def __init__(self, c, c_timestep):
         super().__init__()
-        self.mapper = nn.Linear(c_timestep, c * 2)
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.mapper = linear_cls(c_timestep, c * 2)
 
     def forward(self, x, t):
         a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -42,10 +45,14 @@ class TimestepBlock(nn.Module):
 class ResBlock(nn.Module):
     def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         super().__init__()
-        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+        self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
+            linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
         )
 
     def forward(self, x, x_skip=None):
@@ -73,10 +80,13 @@ class GlobalResponseNorm(nn.Module):
 class AttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
+
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
         self.self_attn = self_attn
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
index ca72ce581f..a7d9e32fb6 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -28,8 +28,9 @@ from ...models.attention_processor import (
     AttnAddedKVProcessor,
     AttnProcessor,
 )
+from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
 from ...models.modeling_utils import ModelMixin
-from ...utils import is_torch_version
+from ...utils import USE_PEFT_BACKEND, is_torch_version
 from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
 
 
@@ -40,12 +41,15 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
         self.c_r = c_r
-        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
+        self.projection = conv_cls(c_in, c, kernel_size=1)
         self.cond_mapper = nn.Sequential(
-            nn.Linear(c_cond, c),
+            linear_cls(c_cond, c),
             nn.LeakyReLU(0.2),
-            nn.Linear(c, c),
+            linear_cls(c, c),
         )
 
         self.blocks = nn.ModuleList()
@@ -55,7 +59,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
         self.out = nn.Sequential(
             WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
-            nn.Conv2d(c, c_in * 2, kernel_size=1),
+            conv_cls(c, c_in * 2, kernel_size=1),
         )
 
         self.gradient_checkpointing = False
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
index 59dbc90b98..5e1b89c0d2 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
@@ -17,11 +17,24 @@ import unittest
 
 import numpy as np
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import (
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+)
 from diffusers.pipelines.wuerstchen import WuerstchenPrior
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+from diffusers.utils.import_utils import is_peft_available
+from diffusers.utils.testing_utils import enable_full_determinism, require_peft_backend, skip_mps, torch_device
+
+
+if is_peft_available():
+    from peft import LoraConfig
+    from peft.tuners.tuners_utils import BaseTunerLayer
 
 from ..test_pipelines_common import PipelineTesterMixin
 
@@ -29,6 +42,19 @@ from ..test_pipelines_common import PipelineTesterMixin
 enable_full_determinism()
 
 
+def create_prior_lora_layers(unet: nn.Module):
+    lora_attn_procs = {}
+    for name in unet.attn_processors.keys():
+        lora_attn_processor_class = (
+            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+        )
+        lora_attn_procs[name] = lora_attn_processor_class(
+            hidden_size=unet.config.c,
+        )
+    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
+    return lora_attn_procs, unet_lora_layers
+
+
 class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = WuerstchenPriorPipeline
     params = ["prompt"]
@@ -219,3 +245,52 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         output = pipe(**inputs)[0]
 
         assert output.abs().sum() == 0
+
+    def check_if_lora_correctly_set(self, model) -> bool:
+        """
+        Checks if the LoRA layers are correctly set with peft
+        """
+        for module in model.modules():
+            if isinstance(module, BaseTunerLayer):
+                return True
+        return False
+
+    def get_lora_components(self):
+        prior = self.dummy_prior
+
+        prior_lora_config = LoraConfig(
+            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+        )
+
+        prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
+
+        lora_components = {
+            "prior_lora_layers": prior_lora_layers,
+            "prior_lora_attn_procs": prior_lora_attn_procs,
+        }
+
+        return prior, prior_lora_config, lora_components
+
+    @require_peft_backend
+    def test_inference_with_prior_lora(self):
+        _, prior_lora_config, _ = self.get_lora_components()
+        device = "cpu"
+
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+
+        pipe.set_progress_bar_config(disable=None)
+
+        output_no_lora = pipe(**self.get_dummy_inputs(device))
+        image_embed = output_no_lora.image_embeddings
+        self.assertTrue(image_embed.shape == (1, 2, 24, 24))
+
+        pipe.prior.add_adapter(prior_lora_config)
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")
+
+        output_lora = pipe(**self.get_dummy_inputs(device))
+        lora_image_embed = output_lora.image_embeddings
+
+        self.assertTrue(image_embed.shape == lora_image_embed.shape)
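
Usage note (editor's sketch, not part of the patch): the snippet below shows how the code path this patch fixes is exercised end to end. When diffusers' PEFT backend is enabled (USE_PEFT_BACKEND is True), the prior is built from plain nn.Linear/nn.Conv2d layers, so a LoRA adapter can be attached with add_adapter, mirroring test_inference_with_prior_lora above. The checkpoint id "warp-ai/wuerstchen-prior", the prompt, and the fp16/CUDA settings are illustrative assumptions, not something this diff prescribes.

    import torch
    from diffusers import WuerstchenPriorPipeline
    from peft import LoraConfig

    # Load the Wuerstchen prior pipeline (checkpoint id assumed for illustration).
    pipe = WuerstchenPriorPipeline.from_pretrained(
        "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
    ).to("cuda")

    # Attach a fresh LoRA adapter to the prior, as the new test does. With the PEFT
    # backend active, peft wraps the plain linear/conv layers created by this patch.
    lora_config = LoraConfig(
        r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
    )
    pipe.prior.add_adapter(lora_config)

    # Run the prior; the image embeddings it produces feed the Wuerstchen decoder stage.
    output = pipe("an astronaut riding a horse")
    print(output.image_embeddings.shape)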