huggingface/diffusers (https://github.com/huggingface/diffusers.git)

Commit bf406ea886, parent 2fd46405cd, committed via GitHub

Correct consist dec (#5722)

* uP
* Update src/diffusers/models/consistency_decoder_vae.py
* uP
* uP
@@ -138,6 +138,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
     def decode(
         self,
         z: torch.FloatTensor,
+        generator: Optional[torch.Generator] = None,
         image: Optional[torch.FloatTensor] = None,
         mask: Optional[torch.FloatTensor] = None,
         return_dict: bool = True,
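This hunk threads an optional `generator` through `AsymmetricAutoencoderKL.decode()`, placed between `z` and the inpainting `image`/`mask` conditioning. The asymmetric decoder itself is deterministic, so the keyword appears to exist mainly so pipelines can call every VAE's `decode()` with the same arguments. A minimal sketch of that call pattern, assuming the class's default single-block configuration constructs cleanly; none of this code is part of the diff:

```python
import torch
from diffusers import AsymmetricAutoencoderKL

# Randomly initialized model with the default configuration
# (assumption: the defaults build a small valid model; real use loads a checkpoint).
vae = AsymmetricAutoencoderKL()

z = torch.randn(1, vae.config.latent_channels, 32, 32)
generator = torch.Generator().manual_seed(0)

# generator, image, and mask are all optional keywords after this change.
image = vae.decode(z, generator=generator).sample
print(image.shape)
```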
@@ -14,7 +14,7 @@
 
 
 from dataclasses import dataclass
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -307,7 +307,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         return AutoencoderTinyOutput(latents=output)
 
     @apply_forward_hook
-    def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    def decode(
+        self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         if self.use_slicing and x.shape[0] > 1:
             output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
             output = torch.cat(output)
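The same pattern is applied to `AutoencoderTiny.decode()` (with the matching `Optional` import added in the previous hunk): `generator` is accepted between `x` and `return_dict` even though the tiny decoder does not sample from it. A short sketch of the tuple-returning branch of the annotation, again assuming the default configuration and not taken from the diff:

```python
import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny()  # default (TAESD-sized) config, randomly initialized
z = torch.randn(1, 4, 16, 16)
generator = torch.Generator().manual_seed(0)

# return_dict=False yields the Tuple[torch.FloatTensor] branch of the annotation.
(image,) = vae.decode(z, generator=generator, return_dict=False)
print(image.shape)  # expected (1, 3, 128, 128) with the default 8x upsampling (assumption)
```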
@@ -68,11 +68,76 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     """
 
     @register_to_config
-    def __init__(self, encoder_args, decoder_args, scaling_factor, block_out_channels, latent_channels):
+    def __init__(
+        self,
+        scaling_factor=0.18215,
+        latent_channels=4,
+        encoder_act_fn="silu",
+        encoder_block_out_channels=(128, 256, 512, 512),
+        encoder_double_z=True,
+        encoder_down_block_types=(
+            "DownEncoderBlock2D",
+            "DownEncoderBlock2D",
+            "DownEncoderBlock2D",
+            "DownEncoderBlock2D",
+        ),
+        encoder_in_channels=3,
+        encoder_layers_per_block=2,
+        encoder_norm_num_groups=32,
+        encoder_out_channels=4,
+        decoder_add_attention=False,
+        decoder_block_out_channels=(320, 640, 1024, 1024),
+        decoder_down_block_types=(
+            "ResnetDownsampleBlock2D",
+            "ResnetDownsampleBlock2D",
+            "ResnetDownsampleBlock2D",
+            "ResnetDownsampleBlock2D",
+        ),
+        decoder_downsample_padding=1,
+        decoder_in_channels=7,
+        decoder_layers_per_block=3,
+        decoder_norm_eps=1e-05,
+        decoder_norm_num_groups=32,
+        decoder_num_train_timesteps=1024,
+        decoder_out_channels=6,
+        decoder_resnet_time_scale_shift="scale_shift",
+        decoder_time_embedding_type="learned",
+        decoder_up_block_types=(
+            "ResnetUpsampleBlock2D",
+            "ResnetUpsampleBlock2D",
+            "ResnetUpsampleBlock2D",
+            "ResnetUpsampleBlock2D",
+        ),
+    ):
         super().__init__()
-        self.encoder = Encoder(**encoder_args)
-        self.decoder_unet = UNet2DModel(**decoder_args)
+        self.encoder = Encoder(
+            act_fn=encoder_act_fn,
+            block_out_channels=encoder_block_out_channels,
+            double_z=encoder_double_z,
+            down_block_types=encoder_down_block_types,
+            in_channels=encoder_in_channels,
+            layers_per_block=encoder_layers_per_block,
+            norm_num_groups=encoder_norm_num_groups,
+            out_channels=encoder_out_channels,
+        )
+
+        self.decoder_unet = UNet2DModel(
+            add_attention=decoder_add_attention,
+            block_out_channels=decoder_block_out_channels,
+            down_block_types=decoder_down_block_types,
+            downsample_padding=decoder_downsample_padding,
+            in_channels=decoder_in_channels,
+            layers_per_block=decoder_layers_per_block,
+            norm_eps=decoder_norm_eps,
+            norm_num_groups=decoder_norm_num_groups,
+            num_train_timesteps=decoder_num_train_timesteps,
+            out_channels=decoder_out_channels,
+            resnet_time_scale_shift=decoder_resnet_time_scale_shift,
+            time_embedding_type=decoder_time_embedding_type,
+            up_block_types=decoder_up_block_types,
+        )
         self.decoder_scheduler = ConsistencyDecoderScheduler()
+        self.register_to_config(block_out_channels=encoder_block_out_channels)
         self.register_buffer(
             "means",
             torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
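The effect of this hunk is that `ConsistencyDecoderVAE` is now configured through flat, individually registered keywords instead of opaque `encoder_args`/`decoder_args` dicts, with the published model as the defaults; the explicit `register_to_config(block_out_channels=encoder_block_out_channels)` keeps `block_out_channels` in the config even though it is no longer an `__init__` parameter. A construction sketch, not part of the diff; the reduced keywords mirror the updated test `init_dict` in the next hunk:

```python
from diffusers import ConsistencyDecoderVAE

# Defaults reproduce the full published configuration that previously had to be
# spelled out via encoder_args/decoder_args.
vae = ConsistencyDecoderVAE()

# A reduced model is now plain keyword arguments, as in the updated test below.
tiny_vae = ConsistencyDecoderVAE(
    encoder_block_out_channels=[32, 64],
    encoder_down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
    decoder_block_out_channels=[32, 64],
    decoder_down_block_types=["ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"],
    decoder_up_block_types=["ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"],
    decoder_layers_per_block=1,
    scaling_factor=1,
    latent_channels=4,
)

# block_out_channels survives in the config via the explicit register_to_config call.
print(tiny_vae.config.block_out_channels)  # expected: [32, 64]
```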
@@ -303,39 +303,30 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
     @property
     def init_dict(self):
         return {
-            "encoder_args": {
-                "block_out_channels": [32, 64],
-                "in_channels": 3,
-                "out_channels": 4,
-                "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            },
-            "decoder_args": {
-                "act_fn": "silu",
-                "add_attention": False,
-                "block_out_channels": [32, 64],
-                "down_block_types": [
-                    "ResnetDownsampleBlock2D",
-                    "ResnetDownsampleBlock2D",
-                ],
-                "downsample_padding": 1,
-                "downsample_type": "conv",
-                "dropout": 0.0,
-                "in_channels": 7,
-                "layers_per_block": 1,
-                "norm_eps": 1e-05,
-                "norm_num_groups": 32,
-                "num_train_timesteps": 1024,
-                "out_channels": 6,
-                "resnet_time_scale_shift": "scale_shift",
-                "time_embedding_type": "learned",
-                "up_block_types": [
-                    "ResnetUpsampleBlock2D",
-                    "ResnetUpsampleBlock2D",
-                ],
-                "upsample_type": "conv",
-            },
+            "encoder_block_out_channels": [32, 64],
+            "encoder_in_channels": 3,
+            "encoder_out_channels": 4,
+            "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            "decoder_add_attention": False,
+            "decoder_block_out_channels": [32, 64],
+            "decoder_down_block_types": [
+                "ResnetDownsampleBlock2D",
+                "ResnetDownsampleBlock2D",
+            ],
+            "decoder_downsample_padding": 1,
+            "decoder_in_channels": 7,
+            "decoder_layers_per_block": 1,
+            "decoder_norm_eps": 1e-05,
+            "decoder_norm_num_groups": 32,
+            "decoder_num_train_timesteps": 1024,
+            "decoder_out_channels": 6,
+            "decoder_resnet_time_scale_shift": "scale_shift",
+            "decoder_time_embedding_type": "learned",
+            "decoder_up_block_types": [
+                "ResnetUpsampleBlock2D",
+                "ResnetUpsampleBlock2D",
+            ],
             "scaling_factor": 1,
-            "block_out_channels": [32, 64],
             "latent_channels": 4,
         }