mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
add tests for AutoencoderKL
@@ -626,11 +626,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         dec = self.decoder(z)
         return dec
 
-    def forward(self, input, sample_posterior=True):
-        posterior = self.encode(input)
+    def forward(self, x, sample_posterior=False):
+        posterior = self.encode(x)
         if sample_posterior:
             z = posterior.sample()
         else:
             z = posterior.mode()
         dec = self.decode(z)
-        return dec, posterior
+        return dec
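The net effect of this hunk: forward now takes x, decodes the posterior mode by default, and returns only the reconstruction instead of a (dec, posterior) tuple. A minimal sketch of the new calling convention, using the dummy checkpoint exercised by the tests added below:

import torch

from diffusers import AutoencoderKL

model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
model.eval()

x = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
with torch.no_grad():
    dec = model(x)  # deterministic: sample_posterior defaults to False, so z = posterior.mode()
    dec_sampled = model(x, sample_posterior=True)  # stochastic: z = posterior.sample()

# forward now returns only the reconstruction, not (dec, posterior);
# for this dummy config the output shape matches the input shape.
assert dec.shape == x.shape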
@@ -46,6 +46,7 @@ from diffusers import (
     UNetLDMModel,
     UNetModel,
     VQModel,
+    AutoencoderKL,
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
@@ -883,6 +884,78 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
 
 
+class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
+    model_class = AutoencoderKL
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 3
+        sizes = (32, 32)
+
+        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+
+        return {"x": image}
+
+    @property
+    def input_shape(self):
+        return (3, 32, 32)
+
+    @property
+    def output_shape(self):
+        return (3, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "ch": 64,
+            "ch_mult": (1,),
+            "embed_dim": 4,
+            "in_channels": 3,
+            "num_res_blocks": 1,
+            "out_ch": 3,
+            "resolution": 32,
+            "z_channels": 4,
+            "attn_resolutions": []
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_forward_signature(self):
+        pass
+
+    def test_training(self):
+        pass
+
+    def test_from_pretrained_hub(self):
+        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
+        self.assertIsNotNone(model)
+        self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+        model.to(torch_device)
+        image = model(**self.dummy_input)
+
+        assert image is not None, "Make sure output is not None"
+
+    def test_output_pretrained(self):
+        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
+        with torch.no_grad():
+            output = model(image, sample_posterior=True)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662,
+                                              0.1750])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+
+
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
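For reference, the tiny model exercised by these tests can also be built directly from the init_dict above. A hedged sketch, assuming the init values are accepted as keyword arguments (which is how the tester mixin passes them):

import torch

from diffusers import AutoencoderKL

model = AutoencoderKL(
    ch=64,
    ch_mult=(1,),
    embed_dim=4,
    in_channels=3,
    num_res_blocks=1,
    out_ch=3,
    resolution=32,
    z_channels=4,
    attn_resolutions=[],
)

x = torch.randn(4, 3, 32, 32)  # matches dummy_input: a batch of 4 RGB 32x32 tensors
posterior = model.encode(x)    # distribution object exposing .sample() and .mode(), per the forward hunk above
dec = model.decode(posterior.mode())

The diff does not show the test file's path, so selecting the new class by name is the path-independent way to run it (assuming pytest as the runner; these unittest-style tests are pytest-collectable): python -m pytest -k AutoEncoderKLTests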