# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from diffusers import AutoencoderDC
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


enable_full_determinism()

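# Fast tests for the AutoencoderDC (DC-AE) model. The config returned by
# `get_autoencoder_dc_config` below is deliberately tiny (8 channels, one layer per
# block) so the shared ModelTesterMixin / UNetTesterMixin checks run quickly on the
# 32x32 dummy inputs defined in `dummy_input`.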
class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AutoencoderDC
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_dc_config(self):
        return {
            "in_channels": 3,
            "latent_channels": 4,
            "attention_head_dim": 2,
            "encoder_block_types": (
                "ResBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (8, 8),
            "decoder_block_out_channels": (8, 8),
            "encoder_qkv_multiscales": ((), (5,)),
            "decoder_qkv_multiscales": ((), (5,)),
            "encoder_layers_per_block": (1, 1),
            "decoder_layers_per_block": (1, 1),
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_dc_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
    def test_forward_with_norm_groups(self):
        pass
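# A typical local invocation for just this module (a sketch; the file path below is an
# assumption about where this test lives in the repository checkout):
#   pytest -q tests/models/autoencoders/test_models_autoencoder_dc.py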