Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Feat] add tiny Autoencoder for (almost) instant decoding (#4384)
* add: model implementation of tiny autoencoder.
* add: inits.
* push the latest devs.
* add: conversion script and finish.
* add: scaling factor args.
* debugging
* fix denormalization.
* fix: positional argument.
* handle use_torch_2_0_or_xformers.
* handle post_quant_conv
* handle dtype
* fix: sdxl image processor for tiny ae.
* fix: sdxl image processor for tiny ae.
* unify upcasting logic.
* copied from madness.
* remove trailing whitespace.
* set is_tiny_vae = False
* address PR comments.
* change to AutoencoderTiny
* make act_fn an str throughout
* fix: apply_forward_hook decorator call
* get rid of the special is_tiny_vae flag.
* directly scale the output.
* fix dummies?
* fix: act_fn.
* get rid of the Clamp() layer.
* bring back copied from.
* movement of the blocks to appropriate modules.
* add: docstrings to AutoencoderTiny
* add: documentation.
* changes to the conversion script.
* add doc entry.
* settle tests.
* style
* add one slow test.
* fix
* fix 2
* fix 2
* fix: 4
* fix: 5
* finish integration tests
* Apply suggestions from code review
  Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
@@ -164,6 +164,8 @@
      title: AutoencoderKL
    - local: api/models/asymmetricautoencoderkl
      title: AsymmetricAutoencoderKL
    - local: api/models/autoencoder_tiny
      title: Tiny AutoEncoder
    - local: api/models/transformer2d
      title: Transformer2D
    - local: api/models/transformer_temporal
docs/source/en/api/models/autoencoder_tiny.md (new file, 45 lines)
@@ -0,0 +1,45 @@

# Tiny AutoEncoder

Tiny AutoEncoder for Stable Diffusion (TAESD) was introduced in [madebyollin/taesd](https://github.com/madebyollin/taesd) by Ollin Boer Bohan. It is a tiny distilled version of Stable Diffusion's VAE that can decode the latents in a [`StableDiffusionPipeline`] or [`StableDiffusionXLPipeline`] almost instantly.

To use with Stable Diffusion v2.1:

```python
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("cheesecake.png")
```

To use with Stable Diffusion XL 1.0:

```python
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("cheesecake_sdxl.png")
```

## AutoencoderTiny

[[autodoc]] AutoencoderTiny

## AutoencoderTinyOutput

[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
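Because [`AutoencoderTiny`] exposes the same `encode`/`decode` interface as the full VAE, it can also be used on its own to round-trip images through the TAESD latent space. The sketch below is illustrative only and not part of the original page; the random tensor simply stands in for a real image normalized to `[-1, 1]`:

```python
import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16).to("cuda")

# Stand-in for a real image batch, values in [-1, 1].
image = torch.rand(1, 3, 512, 512, dtype=torch.float16, device="cuda") * 2 - 1

with torch.no_grad():
    latents = vae.encode(image).latents          # (1, 4, 64, 64)
    reconstruction = vae.decode(latents).sample  # (1, 3, 512, 512), values in [-1, 1]
```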
scripts/convert_tiny_autoencoder_to_diffusers.py (new file, 77 lines)
@@ -0,0 +1,77 @@

import argparse

from diffusers.utils import is_safetensors_available


if is_safetensors_available():
    import safetensors.torch
else:
    raise ImportError("Please install `safetensors`.")

from diffusers import AutoencoderTiny


"""
Example - From the diffusers root directory:

Download the weights:
```sh
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors
```

Convert the model:
```sh
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
    --encoder_ckpt_path taesd_encoder.safetensors \
    --decoder_ckpt_path taesd_decoder.safetensors \
    --dump_path taesd-diffusers
```
"""

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--encoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the encoder ckpt.",
    )
    parser.add_argument(
        "--decoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the decoder ckpt.",
    )
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format."
    )
    args = parser.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path)
    decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path)

    print("Populating the state_dicts in the diffusers format...")
    tiny_autoencoder = AutoencoderTiny()
    new_state_dict = {}

    # Modify the encoder state dict.
    for k in encoder_state_dict:
        new_state_dict.update({f"encoder.layers.{k}": encoder_state_dict[k]})

    # Modify the decoder state dict.
    for k in decoder_state_dict:
        layer_id = int(k.split(".")[0]) - 1
        new_k = str(layer_id) + "." + ".".join(k.split(".")[1:])
        new_state_dict.update({f"decoder.layers.{new_k}": decoder_state_dict[k]})

    # Assertion tests with the original implementation can be found here:
    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
    tiny_autoencoder.load_state_dict(new_state_dict)
    print("Population successful, serializing...")
    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)
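After conversion, the resulting directory can be loaded like any other diffusers checkpoint. A minimal sketch, assuming the `taesd-diffusers` output path used in the example invocation above:

```python
from diffusers import AutoencoderTiny

# Load the checkpoint written by the conversion script.
vae = AutoencoderTiny.from_pretrained("taesd-diffusers")
print(f"parameters: {sum(p.numel() for p in vae.parameters()):,}")
```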
@@ -38,6 +38,7 @@ else:
    from .models import (
        AsymmetricAutoencoderKL,
        AutoencoderKL,
        AutoencoderTiny,
        ControlNetModel,
        ModelMixin,
        MultiAdapter,
@@ -19,6 +19,7 @@ if is_torch_available():
    from .adapter import MultiAdapter, T2IAdapter
    from .autoencoder_asym_kl import AsymmetricAutoencoderKL
    from .autoencoder_kl import AutoencoderKL
    from .autoencoder_tiny import AutoencoderTiny
    from .controlnet import ControlNetModel
    from .dual_transformer_2d import DualTransformer2DModel
    from .modeling_utils import ModelMixin
@@ -8,5 +8,7 @@ def get_activation(act_fn):
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    elif act_fn == "relu":
        return nn.ReLU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")
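The new `relu` branch is what allows `AutoencoderTiny` to request its activation by name. A small usage sketch; the `diffusers.models.activations` import path is assumed from the file this hunk modifies:

```python
import torch.nn as nn
from diffusers.models.activations import get_activation

act = get_activation("relu")
assert isinstance(act, nn.ReLU)

# Unrecognized names fail loudly instead of silently returning None.
try:
    get_activation("not-an-activation")
except ValueError as err:
    print(err)  # Unsupported activation function: not-an-activation
```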
src/diffusers/models/autoencoder_tiny.py (new file, 193 lines)
@@ -0,0 +1,193 @@

# Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, apply_forward_hook
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny


@dataclass
class AutoencoderTinyOutput(BaseOutput):
    """
    Output of AutoencoderTiny encoding method.

    Args:
        latents (`torch.Tensor`): Encoded outputs of the `Encoder`.

    """

    latents: torch.Tensor


class AutoencoderTiny(ModelMixin, ConfigMixin):
    r"""
    A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.

    [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each encoder block. The length of the
            tuple should be equal to the number of encoder blocks.
        decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each decoder block. The length of the
            tuple should be equal to the number of decoder blocks.
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function to be used throughout the model.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of channels in the latent representation. The latent space acts as a compressed representation of
            the input image.
        upsampling_scaling_factor (`int`, *optional*, defaults to 2):
            Scaling factor for upsampling in the decoder. It determines the size of the output image during the
            upsampling process.
        num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
            Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
            length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
            number of encoder blocks.
        num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
            Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
            length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
            number of decoder blocks.
        latent_magnitude (`int`, *optional*, defaults to 3):
            Magnitude of the latent representation. This parameter scales the latent representation values to control
            the extent of information preservation.
        latent_shift (`float`, *optional*, defaults to 0.5):
            Shift applied to the latent representation. This parameter controls the center of the latent space.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
            however, no such scaling factor was used, hence the value of 1.0 as the default.
        force_upcast (`bool`, *optional*, defaults to `False`):
            If enabled, it will force the VAE to run in float32 for high-resolution image pipelines, such as SD-XL.
            The VAE can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` (see this fp16-friendly
            [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        act_fn: str = "relu",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
    ):
        super().__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTiny(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value

    def scale_latents(self, x):
        """raw latents -> [0, 1]"""
        return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

    def unscale_latents(self, x):
        """[0, 1] -> raw latents"""
        return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
        output = self.encoder(x)

        if not return_dict:
            return (output,)

        return AutoencoderTinyOutput(latents=output)

    @apply_forward_hook
    def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        output = self.decoder(x)
        # Refer to the following discussion to know why this is needed.
        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
        output = output.mul_(2).sub_(1)

        if not return_dict:
            return (output,)

        return DecoderOutput(sample=output)

    def forward(
        self,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        enc = self.encode(sample).latents
        scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
        unscaled_enc = self.unscale_latents(scaled_enc)
        dec = self.decode(unscaled_enc)

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
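`scale_latents` and `unscale_latents` map raw latents from roughly `[-latent_magnitude, latent_magnitude]` into `[0, 1]` and back, which is what makes the byte quantization in `forward` possible. A standalone numeric sketch of that mapping with the default `latent_magnitude=3` and `latent_shift=0.5`:

```python
import torch

latent_magnitude, latent_shift = 3, 0.5

raw = torch.tensor([-3.0, 0.0, 3.0])

# raw latents -> [0, 1], mirroring AutoencoderTiny.scale_latents
scaled = (raw / (2 * latent_magnitude) + latent_shift).clamp(0, 1)
print(scaled)  # tensor([0.0000, 0.5000, 1.0000])

# [0, 1] -> raw latents, mirroring AutoencoderTiny.unscale_latents
unscaled = (scaled - latent_shift) * (2 * latent_magnitude)
print(unscaled)  # tensor([-3., 0., 3.])
```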
@@ -19,6 +19,7 @@ import torch.nn.functional as F
from torch import nn

from ..utils import is_torch_version, logging
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from .dual_transformer_2d import DualTransformer2DModel

@@ -423,6 +424,28 @@ def get_up_block(
    raise ValueError(f"{up_block_type} does not exist.")


class AutoencoderTinyBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, act_fn: str):
        super().__init__()
        act_fn = get_activation(act_fn)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            act_fn,
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            act_fn,
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        self.skip = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
            if in_channels != out_channels
            else nn.Identity()
        )
        self.fuse = nn.ReLU()

    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))


class UNetMidBlock2D(nn.Module):
    def __init__(
        self,
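`AutoencoderTinyBlock` is a plain residual block: three 3x3 convolutions with the configured activation, a 1x1 (or identity) skip connection, and a ReLU that fuses the two paths. A quick shape-check sketch based only on the definition above:

```python
import torch
from diffusers.models.unet_2d_blocks import AutoencoderTinyBlock

block = AutoencoderTinyBlock(in_channels=3, out_channels=64, act_fn="relu")

x = torch.randn(1, 3, 32, 32)
y = block(x)

# The convolutions preserve spatial size; only the channel count changes (handled by the 1x1 skip conv).
print(y.shape)  # torch.Size([1, 64, 32, 32])
```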
@@ -19,8 +19,9 @@ import torch
import torch.nn as nn

from ..utils import BaseOutput, is_torch_version, randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block


@dataclass

@@ -686,3 +687,107 @@ class DiagonalGaussianDistribution(object):

    def mode(self):
        return self.mean


class EncoderTiny(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: int,
        block_out_channels: int,
        act_fn: str,
    ):
        super().__init__()

        layers = []
        for i, num_block in enumerate(num_blocks):
            num_channels = block_out_channels[i]

            if i == 0:
                layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
            else:
                layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))

            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

        layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))

        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False

    def forward(self, x):
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
            else:
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)

        else:
            x = self.layers(x)

        return x


class DecoderTiny(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: int,
        block_out_channels: int,
        upsampling_scaling_factor: int,
        act_fn: str,
    ):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
            get_activation(act_fn),
        ]

        for i, num_block in enumerate(num_blocks):
            is_final_block = i == (len(num_blocks) - 1)
            num_channels = block_out_channels[i]

            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

            if not is_final_block:
                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))

            conv_out_channel = num_channels if not is_final_block else out_channels
            layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))

        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False

    def forward(self, x):
        # Clamp.
        x = torch.tanh(x / 3) * 3

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
            else:
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)

        else:
            x = self.layers(x)

        return x
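With the default configuration, `EncoderTiny` applies a stride-2 convolution at every stage after the first (three in total, for an 8x spatial reduction), and `DecoderTiny` undoes it with three `nn.Upsample` layers. A rough shape sketch using a randomly initialized `AutoencoderTiny` with its default config:

```python
import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny()  # random weights, default 4-stage config with 4 latent channels

image = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    latents = vae.encoder(image)  # three stride-2 convs -> 8x smaller
    recon = vae.decoder(latents)  # three 2x upsamples -> original size

print(latents.shape)  # torch.Size([1, 4, 64, 64])
print(recon.shape)    # torch.Size([1, 3, 512, 512])
```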
@@ -32,6 +32,21 @@ class AutoencoderKL(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class AutoencoderTiny(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class ControlNetModel(metaclass=DummyObject):
    _backends = ["torch"]
@@ -19,7 +19,7 @@ import unittest
import torch
from parameterized import parameterized

from diffusers import AsymmetricAutoencoderKL, AutoencoderKL
from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny
from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism

@@ -223,6 +223,83 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        pass


class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
    model_class = AutoencoderTiny
    main_input_name = "sample"
    base_precision = 1e-2

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "encoder_block_out_channels": (32, 32),
            "decoder_block_out_channels": (32, 32),
            "num_encoder_blocks": (1, 2),
            "num_decoder_blocks": (2, 1),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_outputs_equivalence(self):
        pass


@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
        torch_dtype = torch.float16 if fp16 else torch.float32

        model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
        model.to(torch_device).eval()
        return model

    def test_stable_diffusion(self):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed=33)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)


@slow
class AutoencoderKLIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):