From ea8ae8c6397d8333760471e573e4d8ca4646efd0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 15 Jun 2023 17:42:49 +0200 Subject: [PATCH] Complete set_attn_processor for prior and vae (#3796) * relax tolerance slightly * Add more tests * upload readme * upload readme * Apply suggestions from code review * Improve API Autoencoder KL * finalize * finalize tests * finalize tests * Apply suggestions from code review Co-authored-by: Sayak Paul * up --------- Co-authored-by: Sayak Paul --- src/diffusers/models/autoencoder_kl.py | 66 ++++++- src/diffusers/models/prior_transformer.py | 66 ++++++- tests/models/test_modeling_common.py | 153 ++++++++++----- tests/models/test_models_prior.py | 185 ++++++++++++++++++ tests/models/test_models_unet_1d.py | 8 +- tests/models/test_models_unet_2d.py | 11 +- tests/models/test_models_unet_2d_condition.py | 5 +- tests/models/test_models_unet_3d_condition.py | 5 +- tests/models/test_models_vae.py | 6 +- tests/models/test_models_vq.py | 5 +- 10 files changed, 447 insertions(+), 63 deletions(-) create mode 100644 tests/models/test_models_prior.py diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index a4894e78c4..7178543132 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, apply_forward_hook +from .attention_processor import AttentionProcessor, AttnProcessor from .modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder @@ -156,6 +157,69 @@ class AutoencoderKL(ModelMixin, ConfigMixin): """ self.use_slicing = False + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Parameters: + `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `Attention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. 
This is strongly recommended when setting trainable attention processors.: + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + @apply_forward_hook def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index b245612e6f..58804f2672 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, Union +from typing import Dict, Optional, Union import torch import torch.nn.functional as F @@ -8,6 +8,7 @@ from torch import nn from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from .attention import BasicTransformerBlock +from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin @@ -104,6 +105,69 @@ class PriorTransformer(ModelMixin, ConfigMixin): self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim)) self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim)) + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Parameters: + `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `Attention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
+ """ + self.set_attn_processor(AttnProcessor()) + def forward( self, hidden_states, diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index adc18e003a..ee8e55842f 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -26,9 +26,10 @@ import torch from requests.exceptions import HTTPError from diffusers.models import UNet2DConditionModel +from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor from diffusers.training_utils import EMAModel from diffusers.utils import logging, torch_device -from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, run_test_in_subprocess +from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, require_torch_gpu, run_test_in_subprocess # Will be run via run_test_in_subprocess @@ -150,7 +151,43 @@ class ModelUtilsTest(unittest.TestCase): assert model.config.in_channels == 9 +class UNetTesterMixin: + def test_forward_signature(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["sample", "timestep"] + self.assertListEqual(arg_names[:2], expected_arg_names) + + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 32) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + class ModelTesterMixin: + main_input_name = None # overwrite in model specific tester class + base_precision = 1e-3 + def test_from_save_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -170,12 +207,12 @@ class ModelTesterMixin: with torch.no_grad(): image = model(**inputs_dict) if isinstance(image, dict): - image = image.sample + image = image.to_tuple()[0] new_image = new_model(**inputs_dict) if isinstance(new_image, dict): - new_image = new_image.sample + new_image = new_image.to_tuple()[0] max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") @@ -223,12 +260,62 @@ class ModelTesterMixin: assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'" + @require_torch_gpu + def test_set_attn_processor_for_determinism(self): + torch.use_deterministic_algorithms(False) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + if not hasattr(model, "set_attn_processor"): + # If not has `set_attn_processor`, skip test + return + + assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values()) + with torch.no_grad(): + output_1 = model(**inputs_dict)[0] + + model.set_default_attn_processor() + assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values()) + with torch.no_grad(): + output_2 = model(**inputs_dict)[0] + + 
model.enable_xformers_memory_efficient_attention() + assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values()) + with torch.no_grad(): + output_3 = model(**inputs_dict)[0] + + model.set_attn_processor(AttnProcessor2_0()) + assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values()) + with torch.no_grad(): + output_4 = model(**inputs_dict)[0] + + model.set_attn_processor(AttnProcessor()) + assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values()) + with torch.no_grad(): + output_5 = model(**inputs_dict)[0] + + model.set_attn_processor(XFormersAttnProcessor()) + assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values()) + with torch.no_grad(): + output_6 = model(**inputs_dict)[0] + + torch.use_deterministic_algorithms(True) + + # make sure that outputs match + assert torch.allclose(output_2, output_1, atol=self.base_precision) + assert torch.allclose(output_2, output_3, atol=self.base_precision) + assert torch.allclose(output_2, output_4, atol=self.base_precision) + assert torch.allclose(output_2, output_5, atol=self.base_precision) + assert torch.allclose(output_2, output_6, atol=self.base_precision) + def test_from_save_pretrained_variant(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) if hasattr(model, "set_default_attn_processor"): model.set_default_attn_processor() + model.to(torch_device) model.eval() @@ -250,12 +337,12 @@ class ModelTesterMixin: with torch.no_grad(): image = model(**inputs_dict) if isinstance(image, dict): - image = image.sample + image = image.to_tuple()[0] new_image = new_model(**inputs_dict) if isinstance(new_image, dict): - new_image = new_image.sample + new_image = new_image.to_tuple()[0] max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") @@ -293,11 +380,11 @@ class ModelTesterMixin: with torch.no_grad(): first = model(**inputs_dict) if isinstance(first, dict): - first = first.sample + first = first.to_tuple()[0] second = model(**inputs_dict) if isinstance(second, dict): - second = second.sample + second = second.to_tuple()[0] out_1 = first.cpu().numpy() out_2 = second.cpu().numpy() @@ -316,43 +403,15 @@ class ModelTesterMixin: output = model(**inputs_dict) if isinstance(output, dict): - output = output.sample + output = output.to_tuple()[0] self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape + + # input & output have to have the same shape + input_tensor = inputs_dict[self.main_input_name] + expected_shape = input_tensor.shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - def test_forward_with_norm_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["norm_num_groups"] = 16 - init_dict["block_out_channels"] = (16, 32) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_forward_signature(self): - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - signature = inspect.signature(model.forward) - # signature.parameters is 
an OrderedDict => so arg_names order is deterministic - arg_names = [*signature.parameters.keys()] - - expected_arg_names = ["sample", "timestep"] - self.assertListEqual(arg_names[:2], expected_arg_names) - def test_model_from_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -378,12 +437,12 @@ class ModelTesterMixin: output_1 = model(**inputs_dict) if isinstance(output_1, dict): - output_1 = output_1.sample + output_1 = output_1.to_tuple()[0] output_2 = new_model(**inputs_dict) if isinstance(output_2, dict): - output_2 = output_2.sample + output_2 = output_2.to_tuple()[0] self.assertEqual(output_1.shape, output_2.shape) @@ -397,9 +456,10 @@ class ModelTesterMixin: output = model(**inputs_dict) if isinstance(output, dict): - output = output.sample + output = output.to_tuple()[0] - noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) + input_tensor = inputs_dict[self.main_input_name] + noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) loss.backward() @@ -415,9 +475,10 @@ class ModelTesterMixin: output = model(**inputs_dict) if isinstance(output, dict): - output = output.sample + output = output.to_tuple()[0] - noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) + input_tensor = inputs_dict[self.main_input_name] + noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) loss.backward() ema_model.step(model.parameters()) diff --git a/tests/models/test_models_prior.py b/tests/models/test_models_prior.py new file mode 100644 index 0000000000..25b9768ee3 --- /dev/null +++ b/tests/models/test_models_prior.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import inspect +import unittest + +import torch +from parameterized import parameterized + +from diffusers import PriorTransformer +from diffusers.utils import floats_tensor, slow, torch_all_close, torch_device +from diffusers.utils.testing_utils import enable_full_determinism + +from .test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class PriorTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = PriorTransformer + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 4 + embedding_dim = 8 + num_embeddings = 7 + + hidden_states = floats_tensor((batch_size, embedding_dim)).to(torch_device) + + proj_embedding = floats_tensor((batch_size, embedding_dim)).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, num_embeddings, embedding_dim)).to(torch_device) + + return { + "hidden_states": hidden_states, + "timestep": 2, + "proj_embedding": proj_embedding, + "encoder_hidden_states": encoder_hidden_states, + } + + def get_dummy_seed_input(self, seed=0): + torch.manual_seed(seed) + batch_size = 4 + embedding_dim = 8 + num_embeddings = 7 + + hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device) + + proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device) + + return { + "hidden_states": hidden_states, + "timestep": 2, + "proj_embedding": proj_embedding, + "encoder_hidden_states": encoder_hidden_states, + } + + @property + def input_shape(self): + return (4, 8) + + @property + def output_shape(self): + return (4, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "num_attention_heads": 2, + "attention_head_dim": 4, + "num_layers": 2, + "embedding_dim": 8, + "num_embeddings": 7, + "additional_embeddings": 4, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = PriorTransformer.from_pretrained( + "hf-internal-testing/prior-dummy", output_loading_info=True + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + hidden_states = model(**self.dummy_input)[0] + + assert hidden_states is not None, "Make sure output is not None" + + def test_forward_signature(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["hidden_states", "timestep"] + self.assertListEqual(arg_names[:2], expected_arg_names) + + def test_output_pretrained(self): + model = PriorTransformer.from_pretrained("hf-internal-testing/prior-dummy") + model = model.to(torch_device) + + if hasattr(model, "set_default_attn_processor"): + model.set_default_attn_processor() + + input = self.get_dummy_seed_input() + + with torch.no_grad(): + output = model(**input)[0] + + output_slice = output[0, :5].flatten().cpu() + print(output_slice) + + # Since the VAE Gaussian prior's generator is seeded on the appropriate device, + # the expected output slices are not the same for CPU and GPU. 
+ expected_output_slice = torch.tensor([-1.3436, -0.2870, 0.7538, 0.4368, -0.0239]) + self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) + + +@slow +class PriorTransformerIntegrationTests(unittest.TestCase): + def get_dummy_seed_input(self, batch_size=1, embedding_dim=768, num_embeddings=77, seed=0): + torch.manual_seed(seed) + batch_size = batch_size + embedding_dim = embedding_dim + num_embeddings = num_embeddings + + hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device) + + proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device) + + return { + "hidden_states": hidden_states, + "timestep": 2, + "proj_embedding": proj_embedding, + "encoder_hidden_states": encoder_hidden_states, + } + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @parameterized.expand( + [ + # fmt: off + [13, [-0.5861, 0.1283, -0.0931, 0.0882, 0.4476, 0.1329, -0.0498, 0.0640]], + [37, [-0.4913, 0.0110, -0.0483, 0.0541, 0.4954, -0.0170, 0.0354, 0.1651]], + # fmt: on + ] + ) + def test_kandinsky_prior(self, seed, expected_slice): + model = PriorTransformer.from_pretrained("kandinsky-community/kandinsky-2-1-prior", subfolder="prior") + model.to(torch_device) + input = self.get_dummy_seed_input(seed=seed) + + with torch.no_grad(): + sample = model(**input)[0] + + assert list(sample.shape) == [1, 768] + + output_slice = sample[0, :8].flatten().cpu() + print(output_slice) + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 9fb1a61011..99a243e911 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -20,11 +20,12 @@ import torch from diffusers import UNet1DModel from diffusers.utils import floats_tensor, slow, torch_device -from .test_modeling_common import ModelTesterMixin +from .test_modeling_common import ModelTesterMixin, UNetTesterMixin -class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): +class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet1DModel + main_input_name = "sample" @property def dummy_input(self): @@ -153,8 +154,9 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): assert (output_max - 0.0607).abs() < 4e-4 -class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): +class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet1DModel + main_input_name = "sample" @property def dummy_input(self): diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 92a5664daa..4857afb852 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -23,7 +23,7 @@ from diffusers import UNet2DModel from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device from diffusers.utils.testing_utils import enable_full_determinism -from .test_modeling_common import ModelTesterMixin +from .test_modeling_common import ModelTesterMixin, UNetTesterMixin logger = logging.get_logger(__name__) @@ -31,8 +31,9 @@ logger = logging.get_logger(__name__) enable_full_determinism() -class Unet2DModelTests(ModelTesterMixin, unittest.TestCase): +class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, 
unittest.TestCase): model_class = UNet2DModel + main_input_name = "sample" @property def dummy_input(self): @@ -68,8 +69,9 @@ class Unet2DModelTests(ModelTesterMixin, unittest.TestCase): return init_dict, inputs_dict -class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): +class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DModel + main_input_name = "sample" @property def dummy_input(self): @@ -182,8 +184,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3)) -class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): +class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DModel + main_input_name = "sample" @property def dummy_input(self, sizes=(32, 32)): diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 8a3d9dd16f..24da508227 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -36,7 +36,7 @@ from diffusers.utils import ( from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import enable_full_determinism -from .test_modeling_common import ModelTesterMixin +from .test_modeling_common import ModelTesterMixin, UNetTesterMixin logger = logging.get_logger(__name__) @@ -120,8 +120,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): return custom_diffusion_attn_procs -class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): +class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DConditionModel + main_input_name = "sample" @property def dummy_input(self): diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 4193b6e17b..2d3edfffd3 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -31,7 +31,7 @@ from diffusers.utils import ( from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import enable_full_determinism -from .test_modeling_common import ModelTesterMixin +from .test_modeling_common import ModelTesterMixin, UNetTesterMixin enable_full_determinism() @@ -73,8 +73,9 @@ def create_lora_layers(model, mock_weights: bool = True): @skip_mps -class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase): +class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet3DConditionModel + main_input_name = "sample" @property def dummy_input(self): diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index fe27e138f5..08b030bbf9 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -24,14 +24,16 @@ from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slo from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import enable_full_determinism -from .test_modeling_common import ModelTesterMixin +from .test_modeling_common import ModelTesterMixin, UNetTesterMixin enable_full_determinism() -class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): +class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKL + main_input_name = "sample" + base_precision = 1e-2 @property def 
dummy_input(self): diff --git a/tests/models/test_models_vq.py b/tests/models/test_models_vq.py index 8ea6ef77ce..5706c13a0c 100644 --- a/tests/models/test_models_vq.py +++ b/tests/models/test_models_vq.py @@ -21,14 +21,15 @@ from diffusers import VQModel from diffusers.utils import floats_tensor, torch_device from diffusers.utils.testing_utils import enable_full_determinism -from .test_modeling_common import ModelTesterMixin +from .test_modeling_common import ModelTesterMixin, UNetTesterMixin enable_full_determinism() -class VQModelTests(ModelTesterMixin, unittest.TestCase): +class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = VQModel + main_input_name = "sample" @property def dummy_input(self, sizes=(32, 32)):
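
A minimal usage sketch of the attention-processor API this patch adds to AutoencoderKL and PriorTransformer, mirroring what the new test_set_attn_processor_for_determinism test exercises. The dummy checkpoint is the one used in the new prior tests; swapping in AttnProcessor2_0 assumes a PyTorch 2.x install (it falls back with an error otherwise), and the same calls apply to AutoencoderKL.

    import torch
    from diffusers import PriorTransformer
    from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

    model = PriorTransformer.from_pretrained("hf-internal-testing/prior-dummy")

    # `attn_processors` maps each Attention layer's weight name to its processor instance
    print(model.attn_processors)

    # a single processor instance is applied to all Attention layers
    model.set_attn_processor(AttnProcessor2_0())

    # alternatively, pass a dict keyed by the same names that `attn_processors` returns;
    # the number of entries must match the number of attention layers
    model.set_attn_processor({name: AttnProcessor() for name in model.attn_processors})

    # restore the default (vanilla) attention implementation
    model.set_default_attn_processor()

As in the new common test, switching between these processors should leave the forward-pass outputs unchanged within the tester's base_precision tolerance.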