mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[tests] refactor vae tests (#9808)
* add: autoencoderkl tests * autoencodertiny. * fix * asymmetric autoencoder. * more * integration tests for stable audio decoder. * consistency decoder vae tests * remove grad check from consistency decoder. * cog * bye test_models_vae.py * fix * fix * remove allegro * fixes * fixes * fixes --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -433,7 +433,7 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -531,7 +531,7 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -649,7 +649,7 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -789,7 +789,7 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
@@ -798,14 +798,14 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
conv_cache=conv_cache.get("mid_block"),
|
||||
conv_cache.get("mid_block"),
|
||||
)
|
||||
else:
|
||||
# 1. Down
|
||||
for i, down_block in enumerate(self.down_blocks):
|
||||
conv_cache_key = f"down_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = down_block(
|
||||
hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
|
||||
hidden_states, temb, None, conv_cache.get(conv_cache_key)
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
@@ -953,7 +953,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
conv_cache=conv_cache.get("mid_block"),
|
||||
conv_cache.get("mid_block"),
|
||||
)
|
||||
|
||||
# 2. Up
|
||||
@@ -964,7 +964,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
# 1. Mid
|
||||
@@ -1476,7 +1476,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
dec = self.decode(z).sample
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return dec
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@@ -229,14 +229,6 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
)
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (Encoder, TemporalDecoder)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@@ -310,7 +310,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
|
||||
output = [
|
||||
self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
|
||||
]
|
||||
output = torch.cat(output)
|
||||
else:
|
||||
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
|
||||
@@ -341,7 +343,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
# as if we were loading the latents from an RGBA uint8 image.
|
||||
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
|
||||
|
||||
dec = self.decode(unscaled_enc)
|
||||
dec = self.decode(unscaled_enc).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
@@ -0,0 +1,261 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import AsymmetricAutoencoderKL
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AsymmetricAutoencoderKL
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_asym_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
init_dict = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
|
||||
"down_block_out_channels": block_out_channels,
|
||||
"layers_per_down_block": 1,
|
||||
"up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
|
||||
"up_block_out_channels": block_out_channels,
|
||||
"layers_per_up_block": 1,
|
||||
"act_fn": "silu",
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": norm_num_groups,
|
||||
"sample_size": 32,
|
||||
"scaling_factor": 0.18215,
|
||||
}
|
||||
return init_dict
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
mask = torch.ones((batch_size, 1) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image, "mask": mask}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_asym_autoencoder_kl_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return image
|
||||
|
||||
def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
|
||||
revision = "main"
|
||||
torch_dtype = torch.float32
|
||||
|
||||
model = AsymmetricAutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch_dtype,
|
||||
revision=revision,
|
||||
)
|
||||
model.to(torch_device).eval()
|
||||
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
|
||||
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
|
||||
if torch_device != "mps":
|
||||
return torch.Generator(device=generator_device).manual_seed(seed)
|
||||
return torch.manual_seed(seed)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[
|
||||
33,
|
||||
[-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
|
||||
[-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
|
||||
],
|
||||
[
|
||||
47,
|
||||
[0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
|
||||
[-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
|
||||
],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(image, generator=generator, sample_posterior=True).sample
|
||||
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[
|
||||
33,
|
||||
[-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097],
|
||||
[-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078],
|
||||
],
|
||||
[
|
||||
47,
|
||||
[0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
|
||||
[0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
|
||||
],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(image).sample
|
||||
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
|
||||
[37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
@skip_mps
|
||||
def test_stable_diffusion_decode(self, seed, expected_slice):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model.decode(encoding).sample
|
||||
|
||||
assert list(sample.shape) == [3, 3, 512, 512]
|
||||
|
||||
output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=2e-3)
|
||||
|
||||
@parameterized.expand([(13,), (16,), (37,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model.decode(encoding).sample
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
with torch.no_grad():
|
||||
sample_2 = model.decode(encoding).sample
|
||||
|
||||
assert list(sample.shape) == [3, 3, 512, 512]
|
||||
|
||||
assert torch_all_close(sample, sample_2, atol=5e-2)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
|
||||
[47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion_encode_sample(self, seed, expected_slice):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
dist = model.encode(image).latent_dist
|
||||
sample = dist.sample(generator=generator)
|
||||
|
||||
assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
|
||||
|
||||
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
tolerance = 3e-3 if torch_device != "mps" else 1e-2
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
|
||||
468
tests/models/autoencoders/test_models_autoencoder_kl.py
Normal file
468
tests/models/autoencoders/test_models_autoencoder_kl.py
Normal file
@@ -0,0 +1,468 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_fp16,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKL
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
init_dict = {
|
||||
"block_out_channels": block_out_channels,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
|
||||
"up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": norm_num_groups,
|
||||
}
|
||||
return init_dict
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_tiling()
|
||||
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertLess(
|
||||
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
|
||||
0.5,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_tiling()
|
||||
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertEqual(
|
||||
output_without_tiling.detach().cpu().numpy().all(),
|
||||
output_without_tiling_2.detach().cpu().numpy().all(),
|
||||
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
|
||||
)
|
||||
|
||||
def test_enable_disable_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_slicing()
|
||||
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertLess(
|
||||
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
|
||||
0.5,
|
||||
"VAE slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_slicing()
|
||||
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertEqual(
|
||||
output_without_slicing.detach().cpu().numpy().all(),
|
||||
output_without_slicing_2.detach().cpu().numpy().all(),
|
||||
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
|
||||
)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Decoder", "Encoder"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
|
||||
model = model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors
|
||||
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
|
||||
if torch_device != "mps":
|
||||
generator = torch.Generator(device=generator_device).manual_seed(0)
|
||||
else:
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
image = torch.randn(
|
||||
1,
|
||||
model.config.in_channels,
|
||||
model.config.sample_size,
|
||||
model.config.sample_size,
|
||||
generator=torch.manual_seed(0),
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
with torch.no_grad():
|
||||
output = model(image, sample_posterior=True, generator=generator).sample
|
||||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
|
||||
|
||||
# Since the VAE Gaussian prior's generator is seeded on the appropriate device,
|
||||
# the expected output slices are not the same for CPU and GPU.
|
||||
if torch_device == "mps":
|
||||
expected_output_slice = torch.tensor(
|
||||
[
|
||||
-4.0078e-01,
|
||||
-3.8323e-04,
|
||||
-1.2681e-01,
|
||||
-1.1462e-01,
|
||||
2.0095e-01,
|
||||
1.0893e-01,
|
||||
-8.8247e-02,
|
||||
-3.0361e-01,
|
||||
-9.8644e-03,
|
||||
]
|
||||
)
|
||||
elif generator_device == "cpu":
|
||||
expected_output_slice = torch.tensor(
|
||||
[
|
||||
-0.1352,
|
||||
0.0878,
|
||||
0.0419,
|
||||
-0.0818,
|
||||
-0.1069,
|
||||
0.0688,
|
||||
-0.1458,
|
||||
-0.4446,
|
||||
-0.0026,
|
||||
]
|
||||
)
|
||||
else:
|
||||
expected_output_slice = torch.tensor(
|
||||
[
|
||||
-0.2421,
|
||||
0.4642,
|
||||
0.2507,
|
||||
-0.0438,
|
||||
0.0682,
|
||||
0.3160,
|
||||
-0.2018,
|
||||
-0.0727,
|
||||
0.2485,
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return image
|
||||
|
||||
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
|
||||
revision = "fp16" if fp16 else None
|
||||
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||
|
||||
model = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch_dtype,
|
||||
revision=revision,
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
|
||||
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
|
||||
if torch_device != "mps":
|
||||
return torch.Generator(device=generator_device).manual_seed(seed)
|
||||
return torch.manual_seed(seed)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[
|
||||
33,
|
||||
[-0.1556, 0.9848, -0.0410, -0.0642, -0.2685, 0.8381, -0.2004, -0.0700],
|
||||
[-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824],
|
||||
],
|
||||
[
|
||||
47,
|
||||
[-0.2376, 0.1200, 0.1337, -0.4830, -0.2504, -0.0759, -0.0486, -0.4077],
|
||||
[0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131],
|
||||
],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(image, generator=generator, sample_posterior=True).sample
|
||||
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[33, [-0.0513, 0.0289, 1.3799, 0.2166, -0.2573, -0.0871, 0.5103, -0.0999]],
|
||||
[47, [-0.4128, -0.1320, -0.3704, 0.1965, -0.4116, -0.2332, -0.3340, 0.2247]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator_with_fp16
|
||||
def test_stable_diffusion_fp16(self, seed, expected_slice):
|
||||
model = self.get_sd_vae_model(fp16=True)
|
||||
image = self.get_sd_image(seed, fp16=True)
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(image, generator=generator, sample_posterior=True).sample
|
||||
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=1e-2)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[
|
||||
33,
|
||||
[-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814],
|
||||
[-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824],
|
||||
],
|
||||
[
|
||||
47,
|
||||
[-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085],
|
||||
[0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131],
|
||||
],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(image).sample
|
||||
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[13, [-0.2051, -0.1803, -0.2311, -0.2114, -0.3292, -0.3574, -0.2953, -0.3323]],
|
||||
[37, [-0.2632, -0.2625, -0.2199, -0.2741, -0.4539, -0.4990, -0.3720, -0.4925]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
@skip_mps
|
||||
def test_stable_diffusion_decode(self, seed, expected_slice):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model.decode(encoding).sample
|
||||
|
||||
assert list(sample.shape) == [3, 3, 512, 512]
|
||||
|
||||
output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[27, [-0.0369, 0.0207, -0.0776, -0.0682, -0.1747, -0.1930, -0.1465, -0.2039]],
|
||||
[16, [-0.1628, -0.2134, -0.2747, -0.2642, -0.3774, -0.4404, -0.3687, -0.4277]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator_with_fp16
|
||||
def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
|
||||
model = self.get_sd_vae_model(fp16=True)
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model.decode(encoding).sample
|
||||
|
||||
assert list(sample.shape) == [3, 3, 512, 512]
|
||||
|
||||
output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
|
||||
|
||||
@parameterized.expand([(13,), (16,), (27,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
|
||||
model = self.get_sd_vae_model(fp16=True)
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model.decode(encoding).sample
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
with torch.no_grad():
|
||||
sample_2 = model.decode(encoding).sample
|
||||
|
||||
assert list(sample.shape) == [3, 3, 512, 512]
|
||||
|
||||
assert torch_all_close(sample, sample_2, atol=1e-1)
|
||||
|
||||
@parameterized.expand([(13,), (16,), (37,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model.decode(encoding).sample
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
with torch.no_grad():
|
||||
sample_2 = model.decode(encoding).sample
|
||||
|
||||
assert list(sample.shape) == [3, 3, 512, 512]
|
||||
|
||||
assert torch_all_close(sample, sample_2, atol=1e-2)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
|
||||
[47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion_encode_sample(self, seed, expected_slice):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
dist = model.encode(image).latent_dist
|
||||
sample = dist.sample(generator=generator)
|
||||
|
||||
assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
|
||||
|
||||
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
tolerance = 3e-3 if torch_device != "mps" else 1e-2
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
|
||||
@@ -0,0 +1,179 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLCogVideoX
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLCogVideoX
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_cogvideox_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": (
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
),
|
||||
"up_block_types": (
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
),
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"latent_channels": 4,
|
||||
"layers_per_block": 1,
|
||||
"norm_num_groups": 2,
|
||||
"temporal_compression_ratio": 4,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_frames = 8
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 8, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 8, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_cogvideox_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_tiling()
|
||||
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertLess(
|
||||
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
|
||||
0.5,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_tiling()
|
||||
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertEqual(
|
||||
output_without_tiling.detach().cpu().numpy().all(),
|
||||
output_without_tiling_2.detach().cpu().numpy().all(),
|
||||
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
|
||||
)
|
||||
|
||||
def test_enable_disable_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_slicing()
|
||||
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertLess(
|
||||
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
|
||||
0.5,
|
||||
"VAE slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_slicing()
|
||||
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertEqual(
|
||||
output_without_slicing.detach().cpu().numpy().all(),
|
||||
output_without_slicing_2.detach().cpu().numpy().all(),
|
||||
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
|
||||
)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDecoder3D",
|
||||
"CogVideoXEncoder3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXMidBlock3D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["block_out_channels"] = (16, 32, 32, 32)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
@@ -0,0 +1,73 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLTemporalDecoder
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLTemporalDecoder
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 3
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
num_frames = 3
|
||||
|
||||
return {"sample": image, "num_frames": num_frames}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": [32, 64],
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
"latent_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Encoder", "TemporalDecoder"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Test unsupported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
228
tests/models/autoencoders/test_models_autoencoder_oobleck.py
Normal file
228
tests/models/autoencoders/test_models_autoencoder_oobleck.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import AutoencoderOobleck
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderOobleck
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_oobleck_config(self, block_out_channels=None):
|
||||
init_dict = {
|
||||
"encoder_hidden_size": 12,
|
||||
"decoder_channels": 12,
|
||||
"decoder_input_channels": 6,
|
||||
"audio_channels": 2,
|
||||
"downsampling_ratios": [2, 4],
|
||||
"channel_multiples": [1, 2],
|
||||
}
|
||||
return init_dict
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 2
|
||||
seq_len = 24
|
||||
|
||||
waveform = floats_tensor((batch_size, num_channels, seq_len)).to(torch_device)
|
||||
|
||||
return {"sample": waveform, "sample_posterior": False}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (2, 24)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (2, 24)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_oobleck_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_enable_disable_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_slicing()
|
||||
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertLess(
|
||||
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
|
||||
0.5,
|
||||
"VAE slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_slicing()
|
||||
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertEqual(
|
||||
output_without_slicing.detach().cpu().numpy().all(),
|
||||
output_without_slicing_2.detach().cpu().numpy().all(),
|
||||
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
|
||||
)
|
||||
|
||||
@unittest.skip("Test unsupported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("No attention module used in this model")
|
||||
def test_set_attn_processor_for_determinism(self):
|
||||
return
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderOobleckIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
ds = load_dataset(
|
||||
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True
|
||||
)
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.from_numpy(x["array"]) for x in speech_samples], batch_first=True
|
||||
)
|
||||
|
||||
def get_audio(self, audio_sample_size=2097152, fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
audio = self._load_datasamples(2).to(torch_device).to(dtype)
|
||||
|
||||
# pad / crop to audio_sample_size
|
||||
audio = torch.nn.functional.pad(audio[:, :audio_sample_size], pad=(0, audio_sample_size - audio.shape[-1]))
|
||||
|
||||
# todo channel
|
||||
audio = audio.unsqueeze(1).repeat(1, 2, 1).to(torch_device)
|
||||
|
||||
return audio
|
||||
|
||||
def get_oobleck_vae_model(self, model_id="stabilityai/stable-audio-open-1.0", fp16=False):
|
||||
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||
|
||||
model = AutoencoderOobleck.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
|
||||
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
|
||||
if torch_device != "mps":
|
||||
return torch.Generator(device=generator_device).manual_seed(seed)
|
||||
return torch.manual_seed(seed)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
|
||||
[44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion(self, seed, expected_slice, expected_mean_absolute_diff):
|
||||
model = self.get_oobleck_vae_model()
|
||||
audio = self.get_audio()
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(audio, generator=generator, sample_posterior=True).sample
|
||||
|
||||
assert sample.shape == audio.shape
|
||||
assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
|
||||
|
||||
output_slice = sample[-1, 1, 5:10].cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
|
||||
|
||||
def test_stable_diffusion_mode(self):
|
||||
model = self.get_oobleck_vae_model()
|
||||
audio = self.get_audio()
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(audio, sample_posterior=False).sample
|
||||
|
||||
assert sample.shape == audio.shape
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
|
||||
[44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion_encode_decode(self, seed, expected_slice, expected_mean_absolute_diff):
|
||||
model = self.get_oobleck_vae_model()
|
||||
audio = self.get_audio()
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
x = audio
|
||||
posterior = model.encode(x).latent_dist
|
||||
z = posterior.sample(generator=generator)
|
||||
sample = model.decode(z).sample
|
||||
|
||||
# (batch_size, latent_dim, sequence_length)
|
||||
assert posterior.mean.shape == (audio.shape[0], model.config.decoder_input_channels, 1024)
|
||||
|
||||
assert sample.shape == audio.shape
|
||||
assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
|
||||
|
||||
output_slice = sample[-1, 1, 5:10].cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
|
||||
251
tests/models/autoencoders/test_models_autoencoder_tiny.py
Normal file
251
tests/models/autoencoders/test_models_autoencoder_tiny.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import AutoencoderTiny
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderTiny
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_tiny_config(self, block_out_channels=None):
|
||||
block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
|
||||
init_dict = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"encoder_block_out_channels": block_out_channels,
|
||||
"decoder_block_out_channels": block_out_channels,
|
||||
"num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
|
||||
"num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
|
||||
}
|
||||
return init_dict
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_tiny_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Model doesn't yet support smaller resolution.")
|
||||
def test_enable_disable_tiling(self):
|
||||
pass
|
||||
|
||||
def test_enable_disable_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_slicing = model(**inputs_dict)[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_slicing()
|
||||
output_with_slicing = model(**inputs_dict)[0]
|
||||
|
||||
self.assertLess(
|
||||
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
|
||||
0.5,
|
||||
"VAE slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_slicing()
|
||||
output_without_slicing_2 = model(**inputs_dict)[0]
|
||||
|
||||
self.assertEqual(
|
||||
output_without_slicing.detach().cpu().numpy().all(),
|
||||
output_without_slicing_2.detach().cpu().numpy().all(),
|
||||
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
|
||||
)
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DecoderTiny", "EncoderTiny"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
return # Skip test if model does not support gradient checkpointing
|
||||
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
inputs_dict_copy = copy.deepcopy(inputs_dict)
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
torch.manual_seed(0)
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict_copy).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
self.assertTrue((loss - loss_2).abs() < 1e-3)
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
|
||||
for name, param in named_params.items():
|
||||
if "encoder.layers" in name:
|
||||
continue
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=3e-2))
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return image
|
||||
|
||||
def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
|
||||
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||
|
||||
model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
|
||||
model.to(torch_device).eval()
|
||||
return model
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
[(1, 4, 73, 97), (1, 3, 584, 776)],
|
||||
[(1, 4, 97, 73), (1, 3, 776, 584)],
|
||||
[(1, 4, 49, 65), (1, 3, 392, 520)],
|
||||
[(1, 4, 65, 49), (1, 3, 520, 392)],
|
||||
[(1, 4, 49, 49), (1, 3, 392, 392)],
|
||||
]
|
||||
)
|
||||
def test_tae_tiling(self, in_shape, out_shape):
|
||||
model = self.get_sd_vae_model()
|
||||
model.enable_tiling()
|
||||
with torch.no_grad():
|
||||
zeros = torch.zeros(in_shape).to(torch_device)
|
||||
dec = model.decode(zeros).sample
|
||||
assert dec.shape == out_shape
|
||||
|
||||
def test_stable_diffusion(self):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed=33)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(image).sample
|
||||
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||
|
||||
@parameterized.expand([(True,), (False,)])
|
||||
def test_tae_roundtrip(self, enable_tiling):
|
||||
# load the autoencoder
|
||||
model = self.get_sd_vae_model()
|
||||
if enable_tiling:
|
||||
model.enable_tiling()
|
||||
|
||||
# make a black image with a white square in the middle,
|
||||
# which is large enough to split across multiple tiles
|
||||
image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
|
||||
image[..., 256:768, 256:768] = 1.0
|
||||
|
||||
# round-trip the image through the autoencoder
|
||||
with torch.no_grad():
|
||||
sample = model(image).sample
|
||||
|
||||
# the autoencoder reconstruction should match original image, sorta
|
||||
def downscale(x):
|
||||
return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
|
||||
|
||||
assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
|
||||
300
tests/models/autoencoders/test_models_consistency_decoder_vae.py
Normal file
300
tests/models/autoencoders/test_models_consistency_decoder_vae.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = ConsistencyDecoderVAE
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
forward_requires_fresh_args = True
|
||||
|
||||
def get_consistency_vae_config(self, block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
return {
|
||||
"encoder_block_out_channels": block_out_channels,
|
||||
"encoder_in_channels": 3,
|
||||
"encoder_out_channels": 4,
|
||||
"encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
|
||||
"decoder_add_attention": False,
|
||||
"decoder_block_out_channels": block_out_channels,
|
||||
"decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
|
||||
"decoder_downsample_padding": 1,
|
||||
"decoder_in_channels": 7,
|
||||
"decoder_layers_per_block": 1,
|
||||
"decoder_norm_eps": 1e-05,
|
||||
"decoder_norm_num_groups": norm_num_groups,
|
||||
"encoder_norm_num_groups": norm_num_groups,
|
||||
"decoder_num_train_timesteps": 1024,
|
||||
"decoder_out_channels": 6,
|
||||
"decoder_resnet_time_scale_shift": "scale_shift",
|
||||
"decoder_time_embedding_type": "learned",
|
||||
"decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
|
||||
"scaling_factor": 1,
|
||||
"latent_channels": 4,
|
||||
}
|
||||
|
||||
def inputs_dict(self, seed=None):
|
||||
if seed is None:
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator("cpu").manual_seed(seed)
|
||||
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
|
||||
|
||||
return {"sample": image, "generator": generator}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def init_dict(self):
|
||||
return self.get_consistency_vae_config()
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return self.init_dict, self.inputs_dict()
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator")
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_tiling()
|
||||
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertLess(
|
||||
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
|
||||
0.5,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_tiling()
|
||||
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertEqual(
|
||||
output_without_tiling.detach().cpu().numpy().all(),
|
||||
output_without_tiling_2.detach().cpu().numpy().all(),
|
||||
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
|
||||
)
|
||||
|
||||
def test_enable_disable_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator")
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_slicing()
|
||||
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertLess(
|
||||
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
|
||||
0.5,
|
||||
"VAE slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_slicing()
|
||||
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertEqual(
|
||||
output_without_slicing.detach().cpu().numpy().all(),
|
||||
output_without_slicing_2.detach().cpu().numpy().all(),
|
||||
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@torch.no_grad()
|
||||
def test_encode_decode(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
|
||||
vae.to(torch_device)
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
).resize((256, 256))
|
||||
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :].to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
latent = vae.encode(image).latent_dist.mean
|
||||
|
||||
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
|
||||
|
||||
actual_output = sample[0, :2, :2, :2].flatten().cpu()
|
||||
expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
|
||||
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||
|
||||
def test_sd(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, safety_checker=None
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
|
||||
out = pipe(
|
||||
"horse",
|
||||
num_inference_steps=2,
|
||||
output_type="pt",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
actual_output = out[:2, :2, :2].flatten().cpu()
|
||||
expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
|
||||
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||
|
||||
def test_encode_decode_f16(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained(
|
||||
"openai/consistency-decoder", torch_dtype=torch.float16
|
||||
) # TODO - update
|
||||
vae.to(torch_device)
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
).resize((256, 256))
|
||||
image = (
|
||||
torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
|
||||
.half()
|
||||
.to(torch_device)
|
||||
)
|
||||
|
||||
latent = vae.encode(image).latent_dist.mean
|
||||
|
||||
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
|
||||
|
||||
actual_output = sample[0, :2, :2, :2].flatten().cpu()
|
||||
expected_output = torch.tensor(
|
||||
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471],
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||
|
||||
def test_sd_f16(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained(
|
||||
"openai/consistency-decoder", torch_dtype=torch.float16
|
||||
) # TODO - update
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
vae=vae,
|
||||
safety_checker=None,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
|
||||
out = pipe(
|
||||
"horse",
|
||||
num_inference_steps=2,
|
||||
output_type="pt",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
actual_output = out[:2, :2, :2].flatten().cpu()
|
||||
expected_output = torch.tensor(
|
||||
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035],
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||
|
||||
def test_vae_tiling(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
out_1 = pipe(
|
||||
"horse",
|
||||
num_inference_steps=2,
|
||||
output_type="pt",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
# make sure tiled vae decode yields the same result
|
||||
pipe.enable_vae_tiling()
|
||||
out_2 = pipe(
|
||||
"horse",
|
||||
num_inference_steps=2,
|
||||
output_type="pt",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
assert torch_all_close(out_1, out_2, atol=5e-3)
|
||||
|
||||
# test that tiled decode works with various shapes
|
||||
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
|
||||
with torch.no_grad():
|
||||
for shape in shapes:
|
||||
image = torch.zeros(shape, device=torch_device, dtype=pipe.vae.dtype)
|
||||
pipe.vae.decode(image)
|
||||
File diff suppressed because it is too large
Load Diff
86
tests/models/autoencoders/vae.py
Normal file
86
tests/models/autoencoders/vae.py
Normal file
@@ -0,0 +1,86 @@
|
||||
def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
init_dict = {
|
||||
"block_out_channels": block_out_channels,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
|
||||
"up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": norm_num_groups,
|
||||
}
|
||||
return init_dict
|
||||
|
||||
|
||||
def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
init_dict = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
|
||||
"down_block_out_channels": block_out_channels,
|
||||
"layers_per_down_block": 1,
|
||||
"up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
|
||||
"up_block_out_channels": block_out_channels,
|
||||
"layers_per_up_block": 1,
|
||||
"act_fn": "silu",
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": norm_num_groups,
|
||||
"sample_size": 32,
|
||||
"scaling_factor": 0.18215,
|
||||
}
|
||||
return init_dict
|
||||
|
||||
|
||||
def get_autoencoder_tiny_config(block_out_channels=None):
|
||||
block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
|
||||
init_dict = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"encoder_block_out_channels": block_out_channels,
|
||||
"decoder_block_out_channels": block_out_channels,
|
||||
"num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
|
||||
"num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
|
||||
}
|
||||
return init_dict
|
||||
|
||||
|
||||
def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
return {
|
||||
"encoder_block_out_channels": block_out_channels,
|
||||
"encoder_in_channels": 3,
|
||||
"encoder_out_channels": 4,
|
||||
"encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
|
||||
"decoder_add_attention": False,
|
||||
"decoder_block_out_channels": block_out_channels,
|
||||
"decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
|
||||
"decoder_downsample_padding": 1,
|
||||
"decoder_in_channels": 7,
|
||||
"decoder_layers_per_block": 1,
|
||||
"decoder_norm_eps": 1e-05,
|
||||
"decoder_norm_num_groups": norm_num_groups,
|
||||
"encoder_norm_num_groups": norm_num_groups,
|
||||
"decoder_num_train_timesteps": 1024,
|
||||
"decoder_out_channels": 6,
|
||||
"decoder_resnet_time_scale_shift": "scale_shift",
|
||||
"decoder_time_embedding_type": "learned",
|
||||
"decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
|
||||
"scaling_factor": 1,
|
||||
"latent_channels": 4,
|
||||
}
|
||||
|
||||
|
||||
def get_autoencoder_oobleck_config(block_out_channels=None):
|
||||
init_dict = {
|
||||
"encoder_hidden_size": 12,
|
||||
"decoder_channels": 12,
|
||||
"decoder_input_channels": 6,
|
||||
"audio_channels": 2,
|
||||
"downsampling_ratios": [2, 4],
|
||||
"channel_multiples": [1, 2],
|
||||
}
|
||||
return init_dict
|
||||
@@ -858,11 +858,6 @@ class ModelTesterMixin:
|
||||
):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
return # Skip test if model does not support gradient checkpointing
|
||||
if self.model_class.__name__ in [
|
||||
"UNetSpatioTemporalConditionModel",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
]:
|
||||
return
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...models.autoencoders.test_models_vae import (
|
||||
from ...models.autoencoders.vae import (
|
||||
get_asym_autoencoder_kl_config,
|
||||
get_autoencoder_kl_config,
|
||||
get_autoencoder_tiny_config,
|
||||
|
||||
@@ -34,7 +34,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...models.autoencoders.test_models_vae import (
|
||||
from ...models.autoencoders.vae import (
|
||||
get_asym_autoencoder_kl_config,
|
||||
get_autoencoder_kl_config,
|
||||
get_autoencoder_tiny_config,
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..models.autoencoders.test_models_vae import (
|
||||
from ..models.autoencoders.vae import (
|
||||
get_asym_autoencoder_kl_config,
|
||||
get_autoencoder_kl_config,
|
||||
get_autoencoder_tiny_config,
|
||||
|
||||
Reference in New Issue
Block a user