mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
all glide passes
This commit is contained in:
@@ -378,10 +378,7 @@ class ResBlock(TimestepBlock):
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
@@ -426,7 +423,7 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
if time_embedding_norm == "default":
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
if time_embedding_norm == "scale_shift":
|
||||
elif time_embedding_norm == "scale_shift":
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
||||
|
||||
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
|
||||
@@ -489,7 +486,7 @@ class ResnetBlock(nn.Module):
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
||||
2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
@@ -551,9 +548,6 @@ class ResnetBlock(nn.Module):
|
||||
self.set_weights_ldm()
|
||||
self.is_overwritten = True
|
||||
|
||||
if self.up or self.down:
|
||||
x = self.x_upd(x)
|
||||
|
||||
h = x
|
||||
h = h * mask
|
||||
if self.pre_norm:
|
||||
@@ -561,6 +555,7 @@ class ResnetBlock(nn.Module):
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
if self.up or self.down:
|
||||
x = self.x_upd(x)
|
||||
h = self.h_upd(h)
|
||||
|
||||
h = self.conv1(h)
|
||||
@@ -571,7 +566,6 @@ class ResnetBlock(nn.Module):
|
||||
h = h * mask
|
||||
|
||||
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
if self.time_embedding_norm == "scale_shift":
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
|
||||
@@ -595,10 +589,10 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
x = x * mask
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
# if self.use_conv_shortcut:
|
||||
# x = self.conv_shortcut(x)
|
||||
# else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
@@ -101,7 +102,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
|
||||
def forward(self, x, emb, encoder_out=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, AttentionBlock):
|
||||
x = layer(x, encoder_out)
|
||||
@@ -190,14 +191,24 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=int(mult * model_channels),
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
# ResBlock(
|
||||
# ch,
|
||||
# time_embed_dim,
|
||||
# dropout,
|
||||
# out_channels=int(mult * model_channels),
|
||||
# dims=dims,
|
||||
# use_checkpoint=use_checkpoint,
|
||||
# use_scale_shift_norm=use_scale_shift_norm,
|
||||
# )
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
out_channels=mult * model_channels,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift",
|
||||
overwrite_for_glide=True,
|
||||
)
|
||||
]
|
||||
ch = int(mult * model_channels)
|
||||
@@ -218,15 +229,26 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
# ResBlock(
|
||||
# ch,
|
||||
# time_embed_dim,
|
||||
# dropout,
|
||||
# out_channels=out_ch,
|
||||
# dims=dims,
|
||||
# use_checkpoint=use_checkpoint,
|
||||
# use_scale_shift_norm=use_scale_shift_norm,
|
||||
# down=True,
|
||||
# )
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift",
|
||||
overwrite_for_glide=True,
|
||||
down=True
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
@@ -240,13 +262,22 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
# ResBlock(
|
||||
# ch,
|
||||
# time_embed_dim,
|
||||
# dropout,
|
||||
# dims=dims,
|
||||
# use_checkpoint=use_checkpoint,
|
||||
# use_scale_shift_norm=use_scale_shift_norm,
|
||||
# ),
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift",
|
||||
overwrite_for_glide=True,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
@@ -255,14 +286,23 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
# ResBlock(
|
||||
# ch,
|
||||
# time_embed_dim,
|
||||
# dropout,
|
||||
# dims=dims,
|
||||
# use_checkpoint=use_checkpoint,
|
||||
# use_scale_shift_norm=use_scale_shift_norm,
|
||||
# ),
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift",
|
||||
overwrite_for_glide=True,
|
||||
)
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
@@ -271,15 +311,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=int(model_channels * mult),
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
# ResBlock(
|
||||
# ch + ich,
|
||||
# time_embed_dim,
|
||||
# dropout,
|
||||
# out_channels=int(model_channels * mult),
|
||||
# dims=dims,
|
||||
# use_checkpoint=use_checkpoint,
|
||||
# use_scale_shift_norm=use_scale_shift_norm,
|
||||
# )
|
||||
ResnetBlock(
|
||||
in_channels=ch + ich,
|
||||
out_channels=model_channels * mult,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift",
|
||||
overwrite_for_glide=True,
|
||||
),
|
||||
]
|
||||
ch = int(model_channels * mult)
|
||||
if ds in attention_resolutions:
|
||||
@@ -295,14 +345,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
# ResBlock(
|
||||
# ch,
|
||||
# time_embed_dim,
|
||||
# dropout,
|
||||
# out_channels=out_ch,
|
||||
# dims=dims,
|
||||
# use_checkpoint=use_checkpoint,
|
||||
# use_scale_shift_norm=use_scale_shift_norm,
|
||||
# up=True,
|
||||
# )
|
||||
ResnetBlock(
|
||||
in_channels=ch,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift",
|
||||
overwrite_for_glide=True,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
|
||||
Reference in New Issue
Block a user