1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[WIP] implement rest of the test cases (LoRA tests) (#2824)

* inital commit for lora test cases

* help a bit with lora for 3d

* fixed lora tests

* replaced redundant code

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Andy
2023-04-12 06:02:14 -04:00
committed by Daniel Gu
parent a16860e1eb
commit 8344d99d6f
4 changed files with 206 additions and 105 deletions

View File

@@ -251,7 +251,9 @@ class UNetMidBlock3DCrossAttn(nn.Module):
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
).sample
hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
).sample
hidden_states = resnet(hidden_states, temb)
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
@@ -376,7 +378,9 @@ class CrossAttnDownBlock3D(nn.Module):
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
).sample
hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
).sample
output_states += (hidden_states,)
@@ -587,7 +591,9 @@ class CrossAttnUpBlock3D(nn.Module):
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
).sample
hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
).sample
if self.upsamplers is not None:
for upsampler in self.upsamplers:

View File

@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
@@ -50,7 +51,7 @@ class UNet3DConditionOutput(BaseOutput):
sample: torch.FloatTensor
class UNet3DConditionModel(ModelMixin, ConfigMixin):
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
and returns sample shaped output.
@@ -465,7 +466,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
sample = self.conv_in(sample)
sample = self.transformer_in(sample, num_frames=num_frames).sample
sample = self.transformer_in(
sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
).sample
# 3. down
down_block_res_samples = (sample,)

View File

@@ -41,7 +41,7 @@ logger = logging.get_logger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
def create_lora_layers(model):
def create_lora_layers(model, mock_weights: bool = True):
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
@@ -57,12 +57,13 @@ def create_lora_layers(model):
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
if mock_weights:
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
return lora_attn_procs
@@ -378,26 +379,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
sample1 = model(**inputs_dict).sample
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
lora_attn_procs = create_lora_layers(model)
# make sure we can set a list of attention processors
model.set_attn_processor(lora_attn_procs)
@@ -465,28 +447,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
@@ -518,21 +479,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
lora_attn_procs = create_lora_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -553,21 +500,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
lora_attn_procs = create_lora_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:

View File

@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
import torch
from diffusers.models import ModelMixin, UNet3DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
from diffusers.utils import (
floats_tensor,
logging,
@@ -35,10 +37,13 @@ logger = logging.get_logger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
def create_lora_layers(model):
def create_lora_layers(model, mock_weights: bool = True):
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
has_cross_attention = name.endswith("attn2.processor") and not (
name.startswith("transformer_in") or "temp_attentions" in name.split(".")
)
cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
@@ -47,16 +52,20 @@ def create_lora_layers(model):
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
elif name.startswith("transformer_in"):
# Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148
hidden_size = 8 * model.config.attention_head_dim
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
if mock_weights:
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
return lora_attn_procs
@@ -190,23 +199,173 @@ class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
output = model(**inputs_dict)
assert output is not None
# (`attn_processors`) needs to be implemented in this model for this test.
# def test_lora_processors(self):
def test_lora_processors(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
# (`attn_processors`) needs to be implemented in this model for this test.
# def test_lora_save_load(self):
init_dict["attention_head_dim"] = 8
# (`attn_processors`) needs to be implemented for this test in the model.
# def test_lora_save_load_safetensors(self):
model = self.model_class(**init_dict)
model.to(torch_device)
# (`attn_processors`) needs to be implemented for this test in the model.
# def test_lora_save_safetensors_load_torch(self):
with torch.no_grad():
sample1 = model(**inputs_dict).sample
# (`attn_processors`) needs to be implemented for this test.
# def test_lora_save_torch_force_load_safetensors_error(self):
lora_attn_procs = create_lora_layers(model)
# (`attn_processors`) needs to be added for this test.
# def test_lora_on_off(self):
# make sure we can set a list of attention processors
model.set_attn_processor(lora_attn_procs)
model.to(torch_device)
# test that attn processors can be set to itself
model.set_attn_processor(model.attn_processors)
with torch.no_grad():
sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample1 - sample2).abs().max() < 1e-4
assert (sample3 - sample4).abs().max() < 1e-4
# sample 2 and sample 3 should be different
assert (sample2 - sample3).abs().max() > 1e-4
def test_lora_save_load(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample - new_sample).abs().max() < 1e-4
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4
def test_lora_save_load_safetensors(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=True)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample - new_sample).abs().max() < 1e-4
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4
def test_lora_save_safetensors_load_torch(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = create_lora_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
def test_lora_save_torch_force_load_safetensors_error(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = create_lora_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
with self.assertRaises(IOError) as e:
new_model.load_attn_procs(tmpdirname, use_safetensors=True)
self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))
def test_lora_on_off(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
model.set_attn_processor(AttnProcessor())
with torch.no_grad():
new_sample = model(**inputs_dict).sample
assert (sample - new_sample).abs().max() < 1e-4
assert (sample - old_sample).abs().max() < 1e-4
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),