1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[FlaxAutoencoderKL] rename weights to align with PT (#584)

* rename weights to align with PT

* DiagonalGaussianDistribution => FlaxDiagonalGaussianDistribution

* fix name
This commit is contained in:
Suraj Patil
2022-09-20 13:04:16 +02:00
committed by GitHub
parent 0902449ef8
commit c01ec2d119

View File

@@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput):
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
latent_dist (`FlaxDiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
`FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
latent_dist: "FlaxDiagonalGaussianDistribution"
class Upsample2D(nn.Module):
class FlaxUpsample2D(nn.Module):
in_channels: int
dtype: jnp.dtype = jnp.float32
@@ -66,7 +66,7 @@ class Upsample2D(nn.Module):
return hidden_states
class Downsample2D(nn.Module):
class FlaxDownsample2D(nn.Module):
in_channels: int
dtype: jnp.dtype = jnp.float32
@@ -86,7 +86,7 @@ class Downsample2D(nn.Module):
return hidden_states
class ResnetBlock2D(nn.Module):
class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout_prob: float = 0.0
@@ -144,7 +144,7 @@ class ResnetBlock2D(nn.Module):
return hidden_states + residual
class AttentionBlock(nn.Module):
class FlaxAttentionBlock(nn.Module):
channels: int
num_head_channels: int = None
dtype: jnp.dtype = jnp.float32
@@ -201,7 +201,7 @@ class AttentionBlock(nn.Module):
return hidden_states
class DownEncoderBlock2D(nn.Module):
class FlaxDownEncoderBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -214,7 +214,7 @@ class DownEncoderBlock2D(nn.Module):
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = ResnetBlock2D(
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
@@ -224,19 +224,19 @@ class DownEncoderBlock2D(nn.Module):
self.resnets = resnets
if self.add_downsample:
self.downsample = Downsample2D(self.out_channels, dtype=self.dtype)
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
hidden_states = self.downsamplers_0(hidden_states)
return hidden_states
class UpEncoderBlock2D(nn.Module):
class FlaxUpEncoderBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -248,7 +248,7 @@ class UpEncoderBlock2D(nn.Module):
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = ResnetBlock2D(
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
@@ -259,19 +259,19 @@ class UpEncoderBlock2D(nn.Module):
self.resnets = resnets
if self.add_upsample:
self.upsample = Upsample2D(self.out_channels, dtype=self.dtype)
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
hidden_states = self.upsamplers_0(hidden_states)
return hidden_states
class UNetMidBlock2D(nn.Module):
class FlaxUNetMidBlock2D(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
@@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module):
def setup(self):
# there is always at least one resnet
resnets = [
ResnetBlock2D(
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
@@ -292,12 +292,12 @@ class UNetMidBlock2D(nn.Module):
attentions = []
for _ in range(self.num_layers):
attn_block = AttentionBlock(
attn_block = FlaxAttentionBlock(
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
)
attentions.append(attn_block)
res_block = ResnetBlock2D(
res_block = FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
@@ -317,7 +317,7 @@ class UNetMidBlock2D(nn.Module):
return hidden_states
class Encoder(nn.Module):
class FlaxEncoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
@@ -347,7 +347,7 @@ class Encoder(nn.Module):
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = DownEncoderBlock2D(
down_block = FlaxDownEncoderBlock2D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
@@ -358,7 +358,7 @@ class Encoder(nn.Module):
self.down_blocks = down_blocks
# middle
self.mid_block = UNetMidBlock2D(
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
)
@@ -392,7 +392,7 @@ class Encoder(nn.Module):
return sample
class Decoder(nn.Module):
class FlaxDecoder(nn.Module):
dtype: jnp.dtype = jnp.float32
in_channels: int = 3
out_channels: int = 3
@@ -415,7 +415,7 @@ class Decoder(nn.Module):
)
# middle
self.mid_block = UNetMidBlock2D(
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
)
@@ -429,7 +429,7 @@ class Decoder(nn.Module):
is_final_block = i == len(block_out_channels) - 1
up_block = UpEncoderBlock2D(
up_block = FlaxUpEncoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
@@ -469,7 +469,7 @@ class Decoder(nn.Module):
return sample
class DiagonalGaussianDistribution(object):
class FlaxDiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
# Last axis to account for channels-last
self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
@@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.encoder = Encoder(
self.encoder = FlaxEncoder(
in_channels=self.config.in_channels,
out_channels=self.config.latent_channels,
down_block_types=self.config.down_block_types,
@@ -532,7 +532,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
double_z=True,
dtype=self.dtype,
)
self.decoder = Decoder(
self.decoder = FlaxDecoder(
in_channels=self.config.latent_channels,
out_channels=self.config.out_channels,
up_block_types=self.config.up_block_types,
@@ -572,7 +572,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
hidden_states = self.encoder(sample, deterministic=deterministic)
moments = self.quant_conv(hidden_states)
posterior = DiagonalGaussianDistribution(moments)
posterior = FlaxDiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)