mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Feat] add tiny Autoencoder for (almost) instant decoding (#4384)
* add: model implementation of tiny autoencoder. * add: inits. * push the latest devs. * add: conversion script and finish. * add: scaling factor args. * debugging * fix denormalization. * fix: positional argument. * handle use_torch_2_0_or_xformers. * handle post_quant_conv * handle dtype * fix: sdxl image processor for tiny ae. * fix: sdxl image processor for tiny ae. * unify upcasting logic. * copied from madness. * remove trailing whitespace. * set is_tiny_vae = False * address PR comments. * change to AutoencoderTiny * make act_fn an str throughout * fix: apply_forward_hook decorator call * get rid of the special is_tiny_vae flag. * directly scale the output. * fix dummies? * fix: act_fn. * get rid of the Clamp() layer. * bring back copied from. * movement of the blocks to appropriate modules. * add: docstrings to AutoencoderTiny * add: documentation. * changes to the conversion script. * add doc entry. * settle tests. * style * add one slow test. * fix * fix 2 * fix 2 * fix: 4 * fix: 5 * finish integration tests * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * style --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
77
scripts/convert_tiny_autoencoder_to_diffusers.py
Normal file
77
scripts/convert_tiny_autoencoder_to_diffusers.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import argparse
|
||||
|
||||
from diffusers.utils import is_safetensors_available
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
else:
|
||||
raise ImportError("Please install `safetensors`.")
|
||||
|
||||
from diffusers import AutoencoderTiny
|
||||
|
||||
|
||||
"""
|
||||
Example - From the diffusers root directory:
|
||||
|
||||
Download the weights:
|
||||
```sh
|
||||
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors
|
||||
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors
|
||||
```
|
||||
|
||||
Convert the model:
|
||||
```sh
|
||||
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
|
||||
--encoder_ckpt_path taesd_encoder.safetensors \
|
||||
--decoder_ckpt_path taesd_decoder.safetensors \
|
||||
--dump_path taesd-diffusers
|
||||
```
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: read the encoder and decoder checkpoints, rename their
    # keys into a single state dict, load it into an `AutoencoderTiny`, and save
    # the result to `--dump_path`.
    cli = argparse.ArgumentParser()

    cli.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    # The two checkpoint paths share every option except the flag name and help text.
    for flag, description in (
        ("--encoder_ckpt_path", "Path to the encoder ckpt."),
        ("--decoder_ckpt_path", "Path to the decoder ckpt."),
    ):
        cli.add_argument(flag, default=None, type=str, required=True, help=description)
    cli.add_argument(
        "--use_safetensors",
        action="store_true",
        help="Whether to serialize in the safetensors format.",
    )
    args = cli.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    encoder_weights = safetensors.torch.load_file(args.encoder_ckpt_path)
    decoder_weights = safetensors.torch.load_file(args.decoder_ckpt_path)

    print("Populating the state_dicts in the diffusers format...")
    model = AutoencoderTiny()

    # Encoder keys are kept as-is and only gain the `encoder.layers.` prefix.
    converted = {f"encoder.layers.{name}": tensor for name, tensor in encoder_weights.items()}

    # Decoder keys start with an integer layer index; that index is shifted down
    # by one before the `decoder.layers.` prefix is applied.
    # NOTE(review): presumably the original decoder has one leading layer that the
    # diffusers module does not — confirm against the model definition.
    for name, tensor in decoder_weights.items():
        index, _, remainder = name.partition(".")
        converted[f"decoder.layers.{int(index) - 1}.{remainder}"] = tensor

    # Assertion tests with the original implementation can be found here:
    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
    model.load_state_dict(converted)
    print("Population successful, serializing...")
    model.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)
|
||||
Reference in New Issue
Block a user