From 18fc40c169a82da2fca188b5d0083bda6ac044ab Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 2 Aug 2023 23:58:05 +0530
Subject: [PATCH] [Feat] add tiny Autoencoder for (almost) instant decoding
 (#4384)

* add: model implementation of tiny autoencoder.
* add: inits.
* push the latest devs.
* add: conversion script and finish.
* add: scaling factor args.
* debugging
* fix denormalization.
* fix: positional argument.
* handle use_torch_2_0_or_xformers.
* handle post_quant_conv
* handle dtype
* fix: sdxl image processor for tiny ae.
* fix: sdxl image processor for tiny ae.
* unify upcasting logic.
* copied from madness.
* remove trailing whitespace.
* set is_tiny_vae = False
* address PR comments.
* change to AutoencoderTiny
* make act_fn an str throughout
* fix: apply_forward_hook decorator call
* get rid of the special is_tiny_vae flag.
* directly scale the output.
* fix dummies?
* fix: act_fn.
* get rid of the Clamp() layer.
* bring back copied from.
* movement of the blocks to appropriate modules.
* add: docstrings to AutoencoderTiny
* add: documentation.
* changes to the conversion script.
* add doc entry.
* settle tests.
* style
* add one slow test.
* fix
* fix 2
* fix 2
* fix: 4
* fix: 5
* finish integration tests

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
---
 docs/source/en/_toctree.yml                   |   2 +
 docs/source/en/api/models/autoencoder_tiny.md |  45 ++++
 .../convert_tiny_autoencoder_to_diffusers.py  |  77 +++++++
 src/diffusers/__init__.py                     |   1 +
 src/diffusers/models/__init__.py              |   1 +
 src/diffusers/models/activations.py           |   2 +
 src/diffusers/models/autoencoder_tiny.py      | 193 ++++++++++++++++++
 src/diffusers/models/unet_2d_blocks.py        |  23 +++
 src/diffusers/models/vae.py                   | 107 +++++++++-
 src/diffusers/utils/dummy_pt_objects.py       |  15 ++
 tests/models/test_models_vae.py               |  79 ++++++-
 11 files changed, 543 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/en/api/models/autoencoder_tiny.md
 create mode 100644 scripts/convert_tiny_autoencoder_to_diffusers.py
 create mode 100644 src/diffusers/models/autoencoder_tiny.py

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 154728696f..795b023809 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -164,6 +164,8 @@
     title: AutoencoderKL
   - local: api/models/asymmetricautoencoderkl
     title: AsymmetricAutoencoderKL
+  - local: api/models/autoencoder_tiny
+    title: Tiny AutoEncoder
   - local: api/models/transformer2d
     title: Transformer2D
   - local: api/models/transformer_temporal
diff --git a/docs/source/en/api/models/autoencoder_tiny.md b/docs/source/en/api/models/autoencoder_tiny.md
new file mode 100644
index 0000000000..9b97b6e8e9
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_tiny.md
@@ -0,0 +1,45 @@
+# Tiny AutoEncoder
+
+Tiny AutoEncoder for Stable Diffusion (TAESD) was introduced in [madebyollin/taesd](https://github.com/madebyollin/taesd) by Ollin Boer Bohan. It is a tiny, distilled version of Stable Diffusion's VAE that can decode the latents of a [`StableDiffusionPipeline`] or [`StableDiffusionXLPipeline`] almost instantly.
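+
+You can also run the autoencoder on its own to encode images into latents and decode them back. A minimal sketch of the model-level API (the random tensor below is only a stand-in for a real preprocessed image batch):
+
+```python
+import torch
+from diffusers import AutoencoderTiny
+
+vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
+image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image batch
+
+with torch.no_grad():
+    latents = vae.encode(image).latents          # -> (1, 4, 64, 64)
+    reconstruction = vae.decode(latents).sample  # -> (1, 3, 512, 512)
+```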
+
+To use with Stable Diffusion v2.1:
+
+```python
+import torch
+from diffusers import DiffusionPipeline, AutoencoderTiny
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
+)
+pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+prompt = "slice of delicious New York-style berry cheesecake"
+image = pipe(prompt, num_inference_steps=25).images[0]
+image.save("cheesecake.png")
+```
+
+To use with Stable Diffusion XL 1.0:
+
+```python
+import torch
+from diffusers import DiffusionPipeline, AutoencoderTiny
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+)
+pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+prompt = "slice of delicious New York-style berry cheesecake"
+image = pipe(prompt, num_inference_steps=25).images[0]
+image.save("cheesecake_sdxl.png")
+```
+
+## AutoencoderTiny
+
+[[autodoc]] AutoencoderTiny
+
+## AutoencoderTinyOutput
+
+[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
\ No newline at end of file
diff --git a/scripts/convert_tiny_autoencoder_to_diffusers.py b/scripts/convert_tiny_autoencoder_to_diffusers.py
new file mode 100644
index 0000000000..3cf6713e6e
--- /dev/null
+++ b/scripts/convert_tiny_autoencoder_to_diffusers.py
@@ -0,0 +1,77 @@
+import argparse
+
+from diffusers.utils import is_safetensors_available
+
+
+if is_safetensors_available():
+    import safetensors.torch
+else:
+    raise ImportError("Please install `safetensors` first (`pip install safetensors`).")
+
+from diffusers import AutoencoderTiny
+
+
+"""
+Example - From the diffusers root directory:
+
+Download the weights:
+```sh
+$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors
+$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors
+```
+
+Convert the model:
+```sh
+$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
+    --encoder_ckpt_path taesd_encoder.safetensors \
+    --decoder_ckpt_path taesd_decoder.safetensors \
+    --dump_path taesd-diffusers
+```
+"""
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+    parser.add_argument(
+        "--encoder_ckpt_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the encoder ckpt.",
+    )
+    parser.add_argument(
+        "--decoder_ckpt_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the decoder ckpt.",
+    )
+    parser.add_argument(
+        "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format."
+    )
+    args = parser.parse_args()
+
+    print("Loading the original state_dicts of the encoder and the decoder...")
+    encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path)
+    decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path)
+
+    print("Populating the state_dicts in the diffusers format...")
+    tiny_autoencoder = AutoencoderTiny()
+    new_state_dict = {}
+
+    # Remap the encoder keys: the layout already matches `EncoderTiny`, so the
+    # original keys only need the `encoder.layers.` prefix.
+    for k in encoder_state_dict:
+        new_state_dict.update({f"encoder.layers.{k}": encoder_state_dict[k]})
+
+    # Remap the decoder keys.
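+    # The index shift by one accounts for the original decoder's leading Clamp()
+    # module, which this port removes (see "get rid of the Clamp() layer" in the
+    # commit log); `DecoderTiny.forward` applies the equivalent scaled-tanh clamp
+    # instead. For example, "1.weight" becomes "decoder.layers.0.weight".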
+    for k in decoder_state_dict:
+        layer_id = int(k.split(".")[0]) - 1
+        new_k = str(layer_id) + "." + ".".join(k.split(".")[1:])
+        new_state_dict.update({f"decoder.layers.{new_k}": decoder_state_dict[k]})
+
+    # Assertion tests against the original implementation can be found here:
+    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
+    tiny_autoencoder.load_state_dict(new_state_dict)
+    print("Population successful, serializing...")
+    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 3db2683640..9080fed4b8 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -38,6 +38,7 @@ else:
     from .models import (
         AsymmetricAutoencoderKL,
         AutoencoderKL,
+        AutoencoderTiny,
         ControlNetModel,
         ModelMixin,
         MultiAdapter,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index ce093adea9..54e77df0ff 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -19,6 +19,7 @@ if is_torch_available():
     from .adapter import MultiAdapter, T2IAdapter
     from .autoencoder_asym_kl import AsymmetricAutoencoderKL
     from .autoencoder_kl import AutoencoderKL
+    from .autoencoder_tiny import AutoencoderTiny
     from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel
     from .modeling_utils import ModelMixin
diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index 64759b706e..04c978403f 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -8,5 +8,7 @@ def get_activation(act_fn):
         return nn.Mish()
     elif act_fn == "gelu":
         return nn.GELU()
+    elif act_fn == "relu":
+        return nn.ReLU()
     else:
         raise ValueError(f"Unsupported activation function: {act_fn}")
diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py
new file mode 100644
index 0000000000..f6075ef4ce
--- /dev/null
+++ b/src/diffusers/models/autoencoder_tiny.py
@@ -0,0 +1,193 @@
+# Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Tuple, Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, apply_forward_hook
+from .modeling_utils import ModelMixin
+from .vae import DecoderOutput, DecoderTiny, EncoderTiny
+
+
+@dataclass
+class AutoencoderTinyOutput(BaseOutput):
+    """
+    Output of `AutoencoderTiny`'s encoding method.
+
+    Args:
+        latents (`torch.Tensor`): Encoded outputs of the `EncoderTiny`.
+    """
+
+    latents: torch.Tensor
+
+
+class AutoencoderTiny(ModelMixin, ConfigMixin):
+    r"""
+    A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
+
+    [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
+        encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
+            Tuple of integers representing the number of output channels for each encoder block. The length of the
+            tuple should be equal to the number of encoder blocks.
+        decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
+            Tuple of integers representing the number of output channels for each decoder block. The length of the
+            tuple should be equal to the number of decoder blocks.
+        act_fn (`str`, *optional*, defaults to `"relu"`):
+            Activation function to be used throughout the model.
+        latent_channels (`int`, *optional*, defaults to 4):
+            Number of channels in the latent representation. The latent space acts as a compressed representation of
+            the input image.
+        upsampling_scaling_factor (`int`, *optional*, defaults to 2):
+            Scaling factor for upsampling in the decoder. It determines the size of the output image during the
+            upsampling process.
+        num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
+            Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
+            length of the tuple should be equal to the number of stages in the encoder, and each stage can have a
+            different number of encoder blocks.
+        num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
+            Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
+            length of the tuple should be equal to the number of stages in the decoder, and each stage can have a
+            different number of decoder blocks.
+        latent_magnitude (`int`, *optional*, defaults to 3):
+            Magnitude of the latent representation. This parameter scales the latent representation values to control
+            the extent of information preservation.
+        latent_shift (`float`, *optional*, defaults to 0.5):
+            Shift applied to the latent representation. This parameter controls the center of the latent space.
+        scaling_factor (`float`, *optional*, defaults to 1.0):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. No such scaling factor
+            was used for this autoencoder, hence the default value of 1.0.
+        force_upcast (`bool`, *optional*, defaults to `False`):
+            Whether to force the VAE to run in `float32` for high-resolution image pipelines, such as SD-XL. The VAE
+            can be fine-tuned / trained to a lower range without losing too much precision, in which case
+            `force_upcast` can be set to `False` (see this fp16-friendly
+            [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
+ """ + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels=3, + out_channels=3, + encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64), + decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64), + act_fn: str = "relu", + latent_channels: int = 4, + upsampling_scaling_factor: int = 2, + num_encoder_blocks: Tuple[int] = (1, 3, 3, 3), + num_decoder_blocks: Tuple[int] = (3, 3, 3, 1), + latent_magnitude: int = 3, + latent_shift: float = 0.5, + force_upcast: float = False, + scaling_factor: float = 1.0, + ): + super().__init__() + + if len(encoder_block_out_channels) != len(num_encoder_blocks): + raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.") + if len(decoder_block_out_channels) != len(num_decoder_blocks): + raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.") + + self.encoder = EncoderTiny( + in_channels=in_channels, + out_channels=latent_channels, + num_blocks=num_encoder_blocks, + block_out_channels=encoder_block_out_channels, + act_fn=act_fn, + ) + + self.decoder = DecoderTiny( + in_channels=latent_channels, + out_channels=out_channels, + num_blocks=num_decoder_blocks, + block_out_channels=decoder_block_out_channels, + upsampling_scaling_factor=upsampling_scaling_factor, + act_fn=act_fn, + ) + + self.latent_magnitude = latent_magnitude + self.latent_shift = latent_shift + self.scaling_factor = scaling_factor + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (EncoderTiny, DecoderTiny)): + module.gradient_checkpointing = value + + def scale_latents(self, x): + """raw latents -> [0, 1]""" + return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1) + + def unscale_latents(self, x): + """[0, 1] -> raw latents""" + return x.sub(self.latent_shift).mul(2 * self.latent_magnitude) + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]: + output = self.encoder(x) + + if not return_dict: + return (output,) + + return AutoencoderTinyOutput(latents=output) + + @apply_forward_hook + def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: + output = self.decoder(x) + # Refer to the following discussion to know why this is needed. + # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854 + output = output.mul_(2).sub_(1) + + if not return_dict: + return (output,) + + return DecoderOutput(sample=output) + + def forward( + self, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
+ """ + enc = self.encode(sample).latents + scaled_enc = self.scale_latents(enc).mul_(255).round_().byte() + unscaled_enc = self.unscale_latents(scaled_enc) + dec = self.decode(unscaled_enc) + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 469e501b81..8d7e864dfc 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from torch import nn from ..utils import is_torch_version, logging +from .activations import get_activation from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel @@ -423,6 +424,28 @@ def get_up_block( raise ValueError(f"{up_block_type} does not exist.") +class AutoencoderTinyBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, act_fn: str): + super().__init__() + act_fn = get_activation(act_fn) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + if in_channels != out_channels + else nn.Identity() + ) + self.fuse = nn.ReLU() + + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + + class UNetMidBlock2D(nn.Module): def __init__( self, diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 9cd8bcda85..bd9562909d 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -19,8 +19,9 @@ import torch import torch.nn as nn from ..utils import BaseOutput, is_torch_version, randn_tensor +from .activations import get_activation from .attention_processor import SpatialNorm -from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block +from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block @dataclass @@ -686,3 +687,107 @@ class DiagonalGaussianDistribution(object): def mode(self): return self.mean + + +class EncoderTiny(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: int, + block_out_channels: int, + act_fn: str, + ): + super().__init__() + + layers = [] + for i, num_block in enumerate(num_blocks): + num_channels = block_out_channels[i] + + if i == 0: + layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)) + else: + layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False)) + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) + + else: + x = self.layers(x) + + return x + + +class 
+class DecoderTiny(nn.Module):
+    """The tiny, fully-convolutional decoder used by `AutoencoderTiny`."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_blocks: tuple,
+        block_out_channels: tuple,
+        upsampling_scaling_factor: int,
+        act_fn: str,
+    ):
+        super().__init__()
+
+        layers = [
+            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
+            get_activation(act_fn),
+        ]
+
+        for i, num_block in enumerate(num_blocks):
+            is_final_block = i == (len(num_blocks) - 1)
+            num_channels = block_out_channels[i]
+
+            for _ in range(num_block):
+                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
+
+            if not is_final_block:
+                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
+
+            conv_out_channel = num_channels if not is_final_block else out_channels
+            layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
+
+        self.layers = nn.Sequential(*layers)
+        self.gradient_checkpointing = False
+
+    def forward(self, x):
+        # Smoothly clamp the incoming latents to [-3, 3] with a scaled tanh.
+        x = torch.tanh(x / 3) * 3
+
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            if is_torch_version(">=", "1.11.0"):
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
+            else:
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+
+        else:
+            x = self.layers(x)
+
+        return x
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 6c2bb9eed5..f4020b9b5c 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -32,6 +32,21 @@ class AutoencoderKL(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class AutoencoderTiny(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class ControlNetModel(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py
index 0abb2f056e..65c7d49f75 100644
--- a/tests/models/test_models_vae.py
+++ b/tests/models/test_models_vae.py
@@ -19,7 +19,7 @@ import unittest
 import torch
 from parameterized import parameterized
 
-from diffusers import AsymmetricAutoencoderKL, AutoencoderKL
+from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny
 from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism
@@ -223,6 +223,83 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.T
     pass
 
 
+class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
+    model_class = AutoencoderTiny
+    main_input_name = "sample"
+    base_precision = 1e-2
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 3
+        sizes = (32, 32)
+
+        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+
+        return {"sample": image}
+
+    @property
+    def input_shape(self):
+        return (3, 32, 32)
+
+    @property
+    def output_shape(self):
+        return (3, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
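+        # A small two-stage configuration keeps the common tests fast; the real
+        # checkpoints use the defaults (four stages of 64 channels each).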
+        init_dict = {
+            "in_channels": 3,
+            "out_channels": 3,
+            "encoder_block_out_channels": (32, 32),
+            "decoder_block_out_channels": (32, 32),
+            "num_encoder_blocks": (1, 2),
+            "num_decoder_blocks": (2, 1),
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_outputs_equivalence(self):
+        pass
+
+
+@slow
+class AutoencoderTinyIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_file_format(self, seed, shape):
+        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
+
+    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
+        dtype = torch.float16 if fp16 else torch.float32
+        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
+        return image
+
+    def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
+        torch_dtype = torch.float16 if fp16 else torch.float32
+
+        model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
+        model.to(torch_device).eval()
+        return model
+
+    def test_stable_diffusion(self):
+        model = self.get_sd_vae_model()
+        image = self.get_sd_image(seed=33)
+
+        with torch.no_grad():
+            sample = model(image).sample
+
+        assert sample.shape == image.shape
+
+        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
+        expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.0910, -0.2485, 0.0936, 0.0604])
+
+        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
+
+
 @slow
 class AutoencoderKLIntegrationTests(unittest.TestCase):
     def get_file_format(self, seed, shape):