Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Complete set_attn_processor for prior and vae (#3796)
* relax tolerance slightly
* Add more tests
* upload readme
* upload readme
* Apply suggestions from code review
* Improve API Autoencoder KL
* finalize
* finalize tests
* finalize tests
* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* up

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Committed by: GitHub
Parent: 958d9ec723
Commit: ea8ae8c639
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, apply_forward_hook
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder

@@ -156,6 +157,69 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         """
         self.use_slicing = False

+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+                The instantiated processor class, or a dictionary of processor classes, that will be set as the
+                processor of **all** `Attention` layers. If `processor` is a dict, each key must define the path to
+                the corresponding cross-attention processor. This is strongly recommended when setting trainable
+                attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
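
For orientation, here is a minimal usage sketch of the API this hunk adds to AutoencoderKL. The checkpoint name is only an example (any AutoencoderKL checkpoint works); everything else uses the methods introduced above.

from diffusers import AutoencoderKL
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

# Example checkpoint; substitute any VAE checkpoint you have locally.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# New property: every attention processor in the model, keyed by weight name.
print(list(vae.attn_processors.keys()))

# Set one processor class for all Attention layers at once ...
vae.set_attn_processor(AttnProcessor2_0())

# ... or pass a dict keyed exactly like `attn_processors` (needed for per-layer processors).
vae.set_attn_processor({name: AttnProcessor() for name in vae.attn_processors})

# Restore the default attention implementation.
vae.set_default_attn_processor()
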
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import torch
 import torch.nn.functional as F
@@ -8,6 +8,7 @@ from torch import nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
 from .attention import BasicTransformerBlock
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin

@@ -104,6 +105,69 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
         self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))

+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+                The instantiated processor class, or a dictionary of processor classes, that will be set as the
+                processor of **all** `Attention` layers. If `processor` is a dict, each key must define the path to
+                the corresponding cross-attention processor. This is strongly recommended when setting trainable
+                attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def forward(
         self,
         hidden_states,
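
The same three methods are copied onto PriorTransformer. A small sketch of how they might be exercised, using the dummy checkpoint referenced by the new test file further down (shapes mirror its dummy inputs; treat the snippet as illustrative):

import torch

from diffusers import PriorTransformer

# Tiny test checkpoint used by tests/models/test_models_prior.py below.
prior = PriorTransformer.from_pretrained("hf-internal-testing/prior-dummy")

# Pin the plain AttnProcessor so outputs are reproducible across attention backends.
prior.set_default_attn_processor()

torch.manual_seed(0)
batch_size, embedding_dim, num_embeddings = 4, 8, 7
out = prior(
    hidden_states=torch.randn(batch_size, embedding_dim),
    timestep=2,
    proj_embedding=torch.randn(batch_size, embedding_dim),
    encoder_hidden_states=torch.randn(batch_size, num_embeddings, embedding_dim),
)[0]
print(out.shape)  # (4, 8) for this dummy config
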
@@ -26,9 +26,10 @@ import torch
 from requests.exceptions import HTTPError

 from diffusers.models import UNet2DConditionModel
+from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.training_utils import EMAModel
 from diffusers.utils import logging, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, run_test_in_subprocess
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, require_torch_gpu, run_test_in_subprocess


 # Will be run via run_test_in_subprocess
@@ -150,7 +151,43 @@ class ModelUtilsTest(unittest.TestCase):
         assert model.config.in_channels == 9


+class UNetTesterMixin:
+    def test_forward_signature(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        signature = inspect.signature(model.forward)
+        # signature.parameters is an OrderedDict => so arg_names order is deterministic
+        arg_names = [*signature.parameters.keys()]
+
+        expected_arg_names = ["sample", "timestep"]
+        self.assertListEqual(arg_names[:2], expected_arg_names)
+
+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 32)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+
 class ModelTesterMixin:
+    main_input_name = None  # overwrite in model specific tester class
+    base_precision = 1e-3
+
     def test_from_save_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -170,12 +207,12 @@ class ModelTesterMixin:
         with torch.no_grad():
             image = model(**inputs_dict)
             if isinstance(image, dict):
-                image = image.sample
+                image = image.to_tuple()[0]

             new_image = new_model(**inputs_dict)

             if isinstance(new_image, dict):
-                new_image = new_image.sample
+                new_image = new_image.to_tuple()[0]

         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
@@ -223,12 +260,62 @@ class ModelTesterMixin:

         assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"

+    @require_torch_gpu
+    def test_set_attn_processor_for_determinism(self):
+        torch.use_deterministic_algorithms(False)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        if not hasattr(model, "set_attn_processor"):
+            # If the model does not expose `set_attn_processor`, skip the test
+            return
+
+        assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_1 = model(**inputs_dict)[0]
+
+        model.set_default_attn_processor()
+        assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_2 = model(**inputs_dict)[0]
+
+        model.enable_xformers_memory_efficient_attention()
+        assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_3 = model(**inputs_dict)[0]
+
+        model.set_attn_processor(AttnProcessor2_0())
+        assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_4 = model(**inputs_dict)[0]
+
+        model.set_attn_processor(AttnProcessor())
+        assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_5 = model(**inputs_dict)[0]
+
+        model.set_attn_processor(XFormersAttnProcessor())
+        assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_6 = model(**inputs_dict)[0]
+
+        torch.use_deterministic_algorithms(True)
+
+        # make sure that outputs match
+        assert torch.allclose(output_2, output_1, atol=self.base_precision)
+        assert torch.allclose(output_2, output_3, atol=self.base_precision)
+        assert torch.allclose(output_2, output_4, atol=self.base_precision)
+        assert torch.allclose(output_2, output_5, atol=self.base_precision)
+        assert torch.allclose(output_2, output_6, atol=self.base_precision)
+
     def test_from_save_pretrained_variant(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict)
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
+
         model.to(torch_device)
         model.eval()

@@ -250,12 +337,12 @@ class ModelTesterMixin:
         with torch.no_grad():
             image = model(**inputs_dict)
             if isinstance(image, dict):
-                image = image.sample
+                image = image.to_tuple()[0]

             new_image = new_model(**inputs_dict)

             if isinstance(new_image, dict):
-                new_image = new_image.sample
+                new_image = new_image.to_tuple()[0]

         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
@@ -293,11 +380,11 @@ class ModelTesterMixin:
         with torch.no_grad():
             first = model(**inputs_dict)
             if isinstance(first, dict):
-                first = first.sample
+                first = first.to_tuple()[0]

             second = model(**inputs_dict)
             if isinstance(second, dict):
-                second = second.sample
+                second = second.to_tuple()[0]

         out_1 = first.cpu().numpy()
         out_2 = second.cpu().numpy()
@@ -316,43 +403,15 @@ class ModelTesterMixin:
             output = model(**inputs_dict)

             if isinstance(output, dict):
-                output = output.sample
+                output = output.to_tuple()[0]

         self.assertIsNotNone(output)
-        expected_shape = inputs_dict["sample"].shape
+
+        # input & output have to have the same shape
+        input_tensor = inputs_dict[self.main_input_name]
+        expected_shape = input_tensor.shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

-    def test_forward_with_norm_groups(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["norm_num_groups"] = 16
-        init_dict["block_out_channels"] = (16, 32)
-
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-        model.eval()
-
-        with torch.no_grad():
-            output = model(**inputs_dict)
-
-            if isinstance(output, dict):
-                output = output.sample
-
-        self.assertIsNotNone(output)
-        expected_shape = inputs_dict["sample"].shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
-    def test_forward_signature(self):
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict)
-        signature = inspect.signature(model.forward)
-        # signature.parameters is an OrderedDict => so arg_names order is deterministic
-        arg_names = [*signature.parameters.keys()]
-
-        expected_arg_names = ["sample", "timestep"]
-        self.assertListEqual(arg_names[:2], expected_arg_names)
-
     def test_model_from_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -378,12 +437,12 @@ class ModelTesterMixin:
             output_1 = model(**inputs_dict)

             if isinstance(output_1, dict):
-                output_1 = output_1.sample
+                output_1 = output_1.to_tuple()[0]

             output_2 = new_model(**inputs_dict)

             if isinstance(output_2, dict):
-                output_2 = output_2.sample
+                output_2 = output_2.to_tuple()[0]

         self.assertEqual(output_1.shape, output_2.shape)

@@ -397,9 +456,10 @@ class ModelTesterMixin:
         output = model(**inputs_dict)

         if isinstance(output, dict):
-            output = output.sample
+            output = output.to_tuple()[0]

-        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
+        input_tensor = inputs_dict[self.main_input_name]
+        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()

@@ -415,9 +475,10 @@ class ModelTesterMixin:
         output = model(**inputs_dict)

         if isinstance(output, dict):
-            output = output.sample
+            output = output.to_tuple()[0]

-        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
+        input_tensor = inputs_dict[self.main_input_name]
+        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
         ema_model.step(model.parameters())
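
The new `test_set_attn_processor_for_determinism` boils down to: swap processors, run the same input, and check the outputs agree. A standalone sketch of that idea outside the test harness (checkpoint name and tolerance are illustrative, and the xformers leg is omitted since it needs an extra install):

import torch

from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"

# Example checkpoint; any model exposing set_attn_processor behaves the same way.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
sample = torch.randn(1, 3, 64, 64, device=device)

with torch.no_grad():
    out_sdpa = vae(sample).sample  # AttnProcessor2_0 is the default on PyTorch 2.x

vae.set_default_attn_processor()   # switch every layer back to the vanilla AttnProcessor
with torch.no_grad():
    out_vanilla = vae(sample).sample

# Mirrors the test's tolerance (base_precision is 1e-2 for the VAE tester).
assert torch.allclose(out_sdpa, out_vanilla, atol=1e-2)
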
tests/models/test_models_prior.py (new file, 185 lines)
@@ -0,0 +1,185 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import inspect
import unittest

import torch
from parameterized import parameterized

from diffusers import PriorTransformer
from diffusers.utils import floats_tensor, slow, torch_all_close, torch_device
from diffusers.utils.testing_utils import enable_full_determinism

from .test_modeling_common import ModelTesterMixin


enable_full_determinism()


class PriorTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = PriorTransformer
    main_input_name = "hidden_states"

    @property
    def dummy_input(self):
        batch_size = 4
        embedding_dim = 8
        num_embeddings = 7

        hidden_states = floats_tensor((batch_size, embedding_dim)).to(torch_device)

        proj_embedding = floats_tensor((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    def get_dummy_seed_input(self, seed=0):
        torch.manual_seed(seed)
        batch_size = 4
        embedding_dim = 8
        num_embeddings = 7

        hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)

        proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    @property
    def input_shape(self):
        return (4, 8)

    @property
    def output_shape(self):
        return (4, 8)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "num_attention_heads": 2,
            "attention_head_dim": 4,
            "num_layers": 2,
            "embedding_dim": 8,
            "num_embeddings": 7,
            "additional_embeddings": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        model, loading_info = PriorTransformer.from_pretrained(
            "hf-internal-testing/prior-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        hidden_states = model(**self.dummy_input)[0]

        assert hidden_states is not None, "Make sure output is not None"

    def test_forward_signature(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["hidden_states", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_output_pretrained(self):
        model = PriorTransformer.from_pretrained("hf-internal-testing/prior-dummy")
        model = model.to(torch_device)

        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()

        input = self.get_dummy_seed_input()

        with torch.no_grad():
            output = model(**input)[0]

        output_slice = output[0, :5].flatten().cpu()
        print(output_slice)

        # Since the VAE Gaussian prior's generator is seeded on the appropriate device,
        # the expected output slices are not the same for CPU and GPU.
        expected_output_slice = torch.tensor([-1.3436, -0.2870, 0.7538, 0.4368, -0.0239])
        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))


@slow
class PriorTransformerIntegrationTests(unittest.TestCase):
    def get_dummy_seed_input(self, batch_size=1, embedding_dim=768, num_embeddings=77, seed=0):
        torch.manual_seed(seed)
        batch_size = batch_size
        embedding_dim = embedding_dim
        num_embeddings = num_embeddings

        hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)

        proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.5861, 0.1283, -0.0931, 0.0882, 0.4476, 0.1329, -0.0498, 0.0640]],
            [37, [-0.4913, 0.0110, -0.0483, 0.0541, 0.4954, -0.0170, 0.0354, 0.1651]],
            # fmt: on
        ]
    )
    def test_kandinsky_prior(self, seed, expected_slice):
        model = PriorTransformer.from_pretrained("kandinsky-community/kandinsky-2-1-prior", subfolder="prior")
        model.to(torch_device)
        input = self.get_dummy_seed_input(seed=seed)

        with torch.no_grad():
            sample = model(**input)[0]

        assert list(sample.shape) == [1, 768]

        output_slice = sample[0, :8].flatten().cpu()
        print(output_slice)
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
@@ -20,11 +20,12 @@ import torch
 from diffusers import UNet1DModel
 from diffusers.utils import floats_tensor, slow, torch_device

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


-class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
+class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet1DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):
@@ -153,8 +154,9 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
         assert (output_max - 0.0607).abs() < 4e-4


-class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
+class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet1DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):

@@ -23,7 +23,7 @@ from diffusers import UNet2DModel
 from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 logger = logging.get_logger(__name__)
@@ -31,8 +31,9 @@ logger = logging.get_logger(__name__)
 enable_full_determinism()


-class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
+class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):
@@ -68,8 +69,9 @@ class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
         return init_dict, inputs_dict


-class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
+class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):
@@ -182,8 +184,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))


-class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
+class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self, sizes=(32, 32)):

@@ -36,7 +36,7 @@ from diffusers.utils import (
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 logger = logging.get_logger(__name__)
@@ -120,8 +120,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
     return custom_diffusion_attn_procs


-class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
+class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DConditionModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):

@@ -31,7 +31,7 @@ from diffusers.utils import (
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 enable_full_determinism()
@@ -73,8 +73,9 @@ def create_lora_layers(model, mock_weights: bool = True):


 @skip_mps
-class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
+class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet3DConditionModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):

@@ -24,14 +24,16 @@ from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slo
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 enable_full_determinism()


-class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
+    main_input_name = "sample"
+    base_precision = 1e-2

     @property
     def dummy_input(self):

@@ -21,14 +21,15 @@ from diffusers import VQModel
 from diffusers.utils import floats_tensor, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 enable_full_determinism()


-class VQModelTests(ModelTesterMixin, unittest.TestCase):
+class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = VQModel
+    main_input_name = "sample"

     @property
     def dummy_input(self, sizes=(32, 32)):
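
For illustration, this is the pattern the hunks above apply to each tester: opt into `UNetTesterMixin` for the UNet-shaped checks and declare `main_input_name` so the shared `ModelTesterMixin` knows which input to compare against. The class below is a hypothetical example meant to live in tests/models/, not code from this commit:

import unittest

from diffusers import UNet2DModel

from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


class MyUNetModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    # ModelTesterMixin keeps the generic save/load, determinism, and training checks;
    # UNetTesterMixin carries the sample/timestep signature and norm-group tests that
    # were moved out so non-UNet models (e.g. PriorTransformer) can skip them.
    model_class = UNet2DModel
    main_input_name = "sample"

    # A concrete tester still has to provide dummy_input, input_shape/output_shape,
    # and prepare_init_args_and_inputs_for_common(), as the real testers above do.
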