mirror of https://github.com/huggingface/diffusers.git
[FlaxAutoencoderKL] rename weights to align with PT (#584)
* rename weights to align with PT
* DiagonalGaussianDistribution => FlaxDiagonalGaussianDistribution
* fix name
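Why the attribute renames matter: Flax derives the keys of a module's parameter tree from the attribute names assigned in setup(), so registering a submodule as downsamplers_0 (rather than downsample) is what lets the Flax parameters line up with the PyTorch state-dict entries when weights are converted. Below is a minimal sketch of that naming mechanism; the TinyDownBlock module is hypothetical and is not part of this diff.

import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyDownBlock(nn.Module):
    """Hypothetical toy module showing how Flax names parameters."""

    out_channels: int = 4

    def setup(self):
        # The attribute name chosen here becomes the key in the param dict,
        # which is why the diff uses names like `downsamplers_0`.
        self.downsamplers_0 = nn.Conv(self.out_channels, kernel_size=(3, 3), strides=(2, 2))

    def __call__(self, x):
        return self.downsamplers_0(x)


variables = TinyDownBlock().init(jax.random.PRNGKey(0), jnp.ones((1, 8, 8, 4)))
print(list(variables["params"].keys()))  # ['downsamplers_0']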
@@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput):
     Output of AutoencoderKL encoding method.
 
     Args:
-        latent_dist (`DiagonalGaussianDistribution`):
-            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
-            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+        latent_dist (`FlaxDiagonalGaussianDistribution`):
+            Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
+            `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
     """
 
-    latent_dist: "DiagonalGaussianDistribution"
+    latent_dist: "FlaxDiagonalGaussianDistribution"
 
 
-class Upsample2D(nn.Module):
+class FlaxUpsample2D(nn.Module):
     in_channels: int
     dtype: jnp.dtype = jnp.float32
@@ -66,7 +66,7 @@ class Upsample2D(nn.Module):
         return hidden_states
 
 
-class Downsample2D(nn.Module):
+class FlaxDownsample2D(nn.Module):
     in_channels: int
     dtype: jnp.dtype = jnp.float32
@@ -86,7 +86,7 @@ class Downsample2D(nn.Module):
         return hidden_states
 
 
-class ResnetBlock2D(nn.Module):
+class FlaxResnetBlock2D(nn.Module):
     in_channels: int
     out_channels: int = None
     dropout_prob: float = 0.0
@@ -144,7 +144,7 @@ class ResnetBlock2D(nn.Module):
         return hidden_states + residual
 
 
-class AttentionBlock(nn.Module):
+class FlaxAttentionBlock(nn.Module):
     channels: int
     num_head_channels: int = None
     dtype: jnp.dtype = jnp.float32
@@ -201,7 +201,7 @@ class AttentionBlock(nn.Module):
         return hidden_states
 
 
-class DownEncoderBlock2D(nn.Module):
+class FlaxDownEncoderBlock2D(nn.Module):
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -214,7 +214,7 @@ class DownEncoderBlock2D(nn.Module):
         for i in range(self.num_layers):
             in_channels = self.in_channels if i == 0 else self.out_channels
 
-            res_block = ResnetBlock2D(
+            res_block = FlaxResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout_prob=self.dropout,
@@ -224,19 +224,19 @@ class DownEncoderBlock2D(nn.Module):
         self.resnets = resnets
 
         if self.add_downsample:
-            self.downsample = Downsample2D(self.out_channels, dtype=self.dtype)
+            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
 
     def __call__(self, hidden_states, deterministic=True):
         for resnet in self.resnets:
             hidden_states = resnet(hidden_states, deterministic=deterministic)
 
         if self.add_downsample:
-            hidden_states = self.downsample(hidden_states)
+            hidden_states = self.downsamplers_0(hidden_states)
 
         return hidden_states
 
 
-class UpEncoderBlock2D(nn.Module):
+class FlaxUpEncoderBlock2D(nn.Module):
     in_channels: int
     out_channels: int
     dropout: float = 0.0
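With the hunk above, the downsampling layer is registered as downsamplers_0, so its Flax parameter path differs from PyTorch's downsamplers.0 only in the separator between name and index. A rough sketch of that correspondence follows; the key string and the rename rule are illustrative only, not the library's actual conversion code.

import re

# Illustrative PyTorch-style state-dict key (example only).
pt_key = "encoder.down_blocks.0.downsamplers.0.conv.weight"

# Collapse "<name>.<index>" into "<name>_<index>" to obtain Flax-style path
# parts, mirroring attribute names such as `downsamplers_0` from this diff.
flax_key = re.sub(r"\.(\d+)", r"_\1", pt_key)
print(flax_key)  # encoder.down_blocks_0.downsamplers_0.conv.weight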
@@ -248,7 +248,7 @@ class UpEncoderBlock2D(nn.Module):
         resnets = []
         for i in range(self.num_layers):
             in_channels = self.in_channels if i == 0 else self.out_channels
-            res_block = ResnetBlock2D(
+            res_block = FlaxResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout_prob=self.dropout,
@@ -259,19 +259,19 @@ class UpEncoderBlock2D(nn.Module):
         self.resnets = resnets
 
         if self.add_upsample:
-            self.upsample = Upsample2D(self.out_channels, dtype=self.dtype)
+            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
 
     def __call__(self, hidden_states, deterministic=True):
         for resnet in self.resnets:
             hidden_states = resnet(hidden_states, deterministic=deterministic)
 
         if self.add_upsample:
-            hidden_states = self.upsample(hidden_states)
+            hidden_states = self.upsamplers_0(hidden_states)
 
         return hidden_states
 
 
-class UNetMidBlock2D(nn.Module):
+class FlaxUNetMidBlock2D(nn.Module):
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
@@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module):
     def setup(self):
         # there is always at least one resnet
         resnets = [
-            ResnetBlock2D(
+            FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout_prob=self.dropout,
@@ -292,12 +292,12 @@ class UNetMidBlock2D(nn.Module):
         attentions = []
 
         for _ in range(self.num_layers):
-            attn_block = AttentionBlock(
+            attn_block = FlaxAttentionBlock(
                 channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
             )
             attentions.append(attn_block)
 
-            res_block = ResnetBlock2D(
+            res_block = FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout_prob=self.dropout,
@@ -317,7 +317,7 @@ class UNetMidBlock2D(nn.Module):
         return hidden_states
 
 
-class Encoder(nn.Module):
+class FlaxEncoder(nn.Module):
     in_channels: int = 3
     out_channels: int = 3
     down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
@@ -347,7 +347,7 @@ class Encoder(nn.Module):
             output_channel = block_out_channels[i]
             is_final_block = i == len(block_out_channels) - 1
 
-            down_block = DownEncoderBlock2D(
+            down_block = FlaxDownEncoderBlock2D(
                 in_channels=input_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block,
@@ -358,7 +358,7 @@ class Encoder(nn.Module):
         self.down_blocks = down_blocks
 
         # middle
-        self.mid_block = UNetMidBlock2D(
+        self.mid_block = FlaxUNetMidBlock2D(
             in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
         )
@@ -392,7 +392,7 @@ class Encoder(nn.Module):
         return sample
 
 
-class Decoder(nn.Module):
+class FlaxDecoder(nn.Module):
     dtype: jnp.dtype = jnp.float32
     in_channels: int = 3
     out_channels: int = 3
@@ -415,7 +415,7 @@ class Decoder(nn.Module):
         )
 
         # middle
-        self.mid_block = UNetMidBlock2D(
+        self.mid_block = FlaxUNetMidBlock2D(
             in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
         )
@@ -429,7 +429,7 @@ class Decoder(nn.Module):
 
             is_final_block = i == len(block_out_channels) - 1
 
-            up_block = UpEncoderBlock2D(
+            up_block = FlaxUpEncoderBlock2D(
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block + 1,
@@ -469,7 +469,7 @@ class Decoder(nn.Module):
         return sample
 
 
-class DiagonalGaussianDistribution(object):
+class FlaxDiagonalGaussianDistribution(object):
     def __init__(self, parameters, deterministic=False):
         # Last axis to account for channels-last
         self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
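In the hunk above, jnp.split halves the channels-last parameters tensor into a mean and a log-variance. Below is a standalone sketch of what that representation enables, namely drawing a latent by the usual reparameterisation; it mirrors the idea only, not the class's full implementation, and the clipping bounds are an assumption.

import jax
import jax.numpy as jnp

# Stand-in for the `parameters` tensor: twice the latent channels on the
# last (channels-last) axis.
moments = jnp.zeros((1, 32, 32, 8))

mean, logvar = jnp.split(moments, 2, axis=-1)
std = jnp.exp(0.5 * jnp.clip(logvar, -30.0, 20.0))  # clipping keeps std finite

# Reparameterised sample: mean + std * eps, with eps ~ N(0, I).
key = jax.random.PRNGKey(0)
latents = mean + std * jax.random.normal(key, mean.shape)
print(latents.shape)  # (1, 32, 32, 4)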
@@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
-        self.encoder = Encoder(
+        self.encoder = FlaxEncoder(
             in_channels=self.config.in_channels,
             out_channels=self.config.latent_channels,
             down_block_types=self.config.down_block_types,
@@ -532,7 +532,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
             double_z=True,
             dtype=self.dtype,
         )
-        self.decoder = Decoder(
+        self.decoder = FlaxDecoder(
             in_channels=self.config.latent_channels,
             out_channels=self.config.out_channels,
             up_block_types=self.config.up_block_types,
@@ -572,7 +572,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
 
         hidden_states = self.encoder(sample, deterministic=deterministic)
         moments = self.quant_conv(hidden_states)
-        posterior = DiagonalGaussianDistribution(moments)
+        posterior = FlaxDiagonalGaussianDistribution(moments)
 
         if not return_dict:
             return (posterior,)
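End of the encode path: the posterior is a FlaxDiagonalGaussianDistribution, returned either as the bare tuple (posterior,) when return_dict is False, or wrapped in FlaxAutoencoderKLOutput as its latent_dist field. A hedged usage sketch follows, assuming the two classes can be imported from diffusers.models.vae_flax (the file this diff edits) and constructed directly; the moments tensor is a dummy stand-in for the output of quant_conv.

import jax.numpy as jnp

from diffusers.models.vae_flax import (
    FlaxAutoencoderKLOutput,
    FlaxDiagonalGaussianDistribution,
)

# Dummy stand-in for `moments = self.quant_conv(hidden_states)`.
moments = jnp.zeros((1, 32, 32, 8))

posterior = FlaxDiagonalGaussianDistribution(moments)
output = FlaxAutoencoderKLOutput(latent_dist=posterior)

print(output.latent_dist.mean.shape)    # (1, 32, 32, 4)
print(output.latent_dist.logvar.shape)  # (1, 32, 32, 4)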