Mirror of https://github.com/huggingface/diffusers.git
[WIP] implement rest of the test cases (LoRA tests) (#2824)
* initial commit for LoRA test cases
* help a bit with LoRA for 3D
* fixed LoRA tests
* replaced redundant code

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
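In short: the diff threads cross_attention_kwargs (which carries the LoRA scale) through the 3D UNet's temporal attention layers, adds UNet2DConditionLoadersMixin to UNet3DConditionModel so LoRA attention processors can be saved and loaded, and ports the LoRA tests the 2D UNet already had. A minimal sketch of the flow those tests exercise, assuming an illustrative small config (not the exact values used in the test suite) and the create_lora_layers test helper shown further down in this diff:

import torch
from diffusers.models import UNet3DConditionModel

# Illustrative small config; the real tests build theirs via prepare_init_args_and_inputs_for_common().
model = UNet3DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

# Dummy inputs: video latents (batch, channels, frames, height, width), a timestep, text embeddings.
sample = torch.randn(1, 4, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)

# `create_lora_layers` is the test helper updated below; it builds one LoRAAttnProcessor per attention module.
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)

with torch.no_grad():
    # The LoRA scale rides along in cross_attention_kwargs; 0.0 should reproduce the base model.
    out = model(sample, timestep, encoder_hidden_states, cross_attention_kwargs={"scale": 0.5}).sample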
@@ -251,7 +251,9 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample
             hidden_states = resnet(hidden_states, temb)
             hidden_states = temp_conv(hidden_states, num_frames=num_frames)
 
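The reason the extra keyword matters: in a LoRA-enabled attention processor, the scale entry of cross_attention_kwargs weights the low-rank update that is added on top of the frozen projection, so until the temporal attention calls forwarded the kwarg, their LoRA contribution could not be scaled or switched off. A simplified sketch of that projection step, not the full LoRAAttnProcessor implementation:

# Simplified sketch of how a LoRA-aware processor applies `scale` to one projection.
# The real LoRAAttnProcessor does the same for the key, value, and output projections too.
def lora_query_projection(attn, hidden_states, to_q_lora, scale=1.0):
    base = attn.to_q(hidden_states)   # frozen pretrained weight
    lora = to_q_lora(hidden_states)   # trained low-rank update
    return base + scale * lora        # scale=0.0 falls back to the base model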
@@ -376,7 +378,9 @@ class CrossAttnDownBlock3D(nn.Module):
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample
 
             output_states += (hidden_states,)
 
@@ -587,7 +591,9 @@ class CrossAttnUpBlock3D(nn.Module):
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -20,6 +20,7 @@ import torch.nn as nn
 import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
@@ -50,7 +51,7 @@ class UNet3DConditionOutput(BaseOutput):
     sample: torch.FloatTensor
 
 
-class UNet3DConditionModel(ModelMixin, ConfigMixin):
+class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
     UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
     and returns sample shaped output.
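Inheriting from UNet2DConditionLoadersMixin is what gives the 3D UNet the save_attn_procs / load_attn_procs API exercised by the serialization tests below. A short usage sketch (the directory name is a placeholder; `model` and `new_model` are UNet3DConditionModel instances):

# Persist the currently attached LoRA attention processors, then load them into a fresh model.
model.save_attn_procs("lora-weights")      # writes pytorch_lora_weights.bin (or .safetensors with safe_serialization=True)
new_model.load_attn_procs("lora-weights")  # re-attaches the LoRA processors to another UNet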
@@ -465,7 +466,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
         sample = self.conv_in(sample)
 
-        sample = self.transformer_in(sample, num_frames=num_frames).sample
+        sample = self.transformer_in(
+            sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+        ).sample
 
         # 3. down
         down_block_res_samples = (sample,)
@@ -41,7 +41,7 @@ logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-def create_lora_layers(model):
+def create_lora_layers(model, mock_weights: bool = True):
     lora_attn_procs = {}
     for name in model.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
@@ -57,12 +57,13 @@ def create_lora_layers(model):
         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
 
-        # add 1 to weights to mock trained weights
-        with torch.no_grad():
-            lora_attn_procs[name].to_q_lora.up.weight += 1
-            lora_attn_procs[name].to_k_lora.up.weight += 1
-            lora_attn_procs[name].to_v_lora.up.weight += 1
-            lora_attn_procs[name].to_out_lora.up.weight += 1
+        if mock_weights:
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1
 
     return lora_attn_procs
 
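The new mock_weights flag separates two kinds of tests: the numerical tests keep the `+= 1` perturbation so the LoRA output visibly differs from the base model, while the serialization tests pass mock_weights=False and leave the freshly initialized processors untouched, since only the save/load round trip matters there. Roughly:

# Numerical tests: perturbed LoRA weights so outputs change measurably.
lora_attn_procs = create_lora_layers(model)                      # mock_weights=True by default
model.set_attn_processor(lora_attn_procs)

# Serialization tests: untouched LoRA weights, only the round trip is checked.
lora_attn_procs = create_lora_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)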
@@ -378,26 +379,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             sample1 = model(**inputs_dict).sample
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
+        lora_attn_procs = create_lora_layers(model)
 
         # make sure we can set a list of attention processors
         model.set_attn_processor(lora_attn_procs)
@@ -465,28 +447,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
-
+        lora_attn_procs = create_lora_layers(model)
         model.set_attn_processor(lora_attn_procs)
 
         with torch.no_grad():
@@ -518,21 +479,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -553,21 +500,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import tempfile
 import unittest
 
 import numpy as np
 import torch
 
 from diffusers.models import ModelMixin, UNet3DConditionModel
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     logging,
@@ -35,10 +37,13 @@ logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-def create_lora_layers(model):
+def create_lora_layers(model, mock_weights: bool = True):
     lora_attn_procs = {}
     for name in model.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        has_cross_attention = name.endswith("attn2.processor") and not (
+            name.startswith("transformer_in") or "temp_attentions" in name.split(".")
+        )
+        cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None
         if name.startswith("mid_block"):
             hidden_size = model.config.block_out_channels[-1]
         elif name.startswith("up_blocks"):
@@ -47,16 +52,20 @@ def create_lora_layers(model):
         elif name.startswith("down_blocks"):
             block_id = int(name[len("down_blocks.")])
             hidden_size = model.config.block_out_channels[block_id]
+        elif name.startswith("transformer_in"):
+            # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148
+            hidden_size = 8 * model.config.attention_head_dim
 
         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
 
-        # add 1 to weights to mock trained weights
-        with torch.no_grad():
-            lora_attn_procs[name].to_q_lora.up.weight += 1
-            lora_attn_procs[name].to_k_lora.up.weight += 1
-            lora_attn_procs[name].to_v_lora.up.weight += 1
-            lora_attn_procs[name].to_out_lora.up.weight += 1
+        if mock_weights:
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1
 
     return lora_attn_procs
 
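Compared with the 2D helper, the 3D variant has to handle two extra cases: temporal attention modules and the input transformer_in block are self-attention only (so they get cross_attention_dim=None), and transformer_in does not sit under a down/mid/up block, so its hidden size comes from the `8 * attention_head_dim` used when the model builds its input temporal transformer (see the linked source line above). A small worked check of the sizes the helper resolves, using the illustrative config from the sketch at the top:

# Hidden sizes create_lora_layers would resolve for block_out_channels=(32, 64), attention_head_dim=8.
block_out_channels = (32, 64)
attention_head_dim = 8

assert block_out_channels[0] == 32                     # "down_blocks.0..." processors
assert block_out_channels[-1] == 64                    # "mid_block..." processors
assert list(reversed(block_out_channels))[0] == 64     # "up_blocks.0..." processors
assert 8 * attention_head_dim == 64                    # "transformer_in..." processors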
@@ -190,23 +199,173 @@ class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         output = model(**inputs_dict)
         assert output is not None
 
-    # (`attn_processors`) needs to be implemented in this model for this test.
-    # def test_lora_processors(self):
-
-    # (`attn_processors`) needs to be implemented in this model for this test.
-    # def test_lora_save_load(self):
-
-    # (`attn_processors`) needs to be implemented for this test in the model.
-    # def test_lora_save_load_safetensors(self):
-
-    # (`attn_processors`) needs to be implemented for this test in the model.
-    # def test_lora_save_safetensors_load_torch(self):
-
-    # (`attn_processors`) needs to be implemented for this test.
-    # def test_lora_save_torch_force_load_safetensors_error(self):
-
-    # (`attn_processors`) needs to be added for this test.
-    # def test_lora_on_off(self):
+    def test_lora_processors(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            sample1 = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+
+        # make sure we can set a list of attention processors
+        model.set_attn_processor(lora_attn_procs)
+        model.to(torch_device)
+
+        # test that attn processors can be set to itself
+        model.set_attn_processor(model.attn_processors)
+
+        with torch.no_grad():
+            sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+            sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+            sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample1 - sample2).abs().max() < 1e-4
+        assert (sample3 - sample4).abs().max() < 1e-4
+
+        # sample 2 and sample 3 should be different
+        assert (sample2 - sample3).abs().max() > 1e-4
+
+    def test_lora_save_load(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_load_safetensors(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname, safe_serialization=True)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_safetensors_load_torch(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
+        model.set_attn_processor(lora_attn_procs)
+        # Saving as torch, properly reloads with directly filename
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
+
+    def test_lora_save_torch_force_load_safetensors_error(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
+        model.set_attn_processor(lora_attn_procs)
+        # Saving as torch, properly reloads with directly filename
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            with self.assertRaises(IOError) as e:
+                new_model.load_attn_procs(tmpdirname, use_safetensors=True)
+            self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))
+
+    def test_lora_on_off(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+
+        model.set_attn_processor(AttnProcessor())
+
+        with torch.no_grad():
+            new_sample = model(**inputs_dict).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - old_sample).abs().max() < 1e-4
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
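Taken together, test_lora_processors and test_lora_on_off pin down the invariant the rest of the suite relies on: with LoRA processors attached, a forward pass at scale 0.0 must match both the original model and the model after resetting to a plain AttnProcessor. A condensed sketch of that contract, assuming `model`, `inputs_dict`, and `lora_attn_procs` are set up as in the tests above:

import torch
from diffusers.models.attention_processor import AttnProcessor

def check_lora_is_inert_at_zero_scale(model, inputs_dict, lora_attn_procs, atol=1e-4):
    """Sketch of the invariant test_lora_on_off asserts; not a drop-in test."""
    with torch.no_grad():
        base = model(**inputs_dict).sample                # before LoRA is attached

    model.set_attn_processor(lora_attn_procs)
    with torch.no_grad():
        off = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample

    model.set_attn_processor(AttnProcessor())             # strip LoRA again
    with torch.no_grad():
        plain = model(**inputs_dict).sample

    assert (base - off).abs().max() < atol
    assert (off - plain).abs().max() < atol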