mirror of https://github.com/huggingface/diffusers.git

all glide passes

Patrick von Platen
2022-06-30 22:09:49 +00:00
parent db934c6750
commit fd6f93b2b1
2 changed files with 117 additions and 62 deletions

View File

@@ -378,10 +378,7 @@ class ResBlock(TimestepBlock):
         h = self.conv2(h)

         if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
-            else:
-                x = self.nin_shortcut(x)
+            x = self.nin_shortcut(x)

         return x + h
@@ -426,7 +423,7 @@ class ResnetBlock(nn.Module):
         if time_embedding_norm == "default":
             self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-        if time_embedding_norm == "scale_shift":
+        elif time_embedding_norm == "scale_shift":
             self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

         self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
@@ -489,7 +486,7 @@ class ResnetBlock(nn.Module):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
             ),
         )
         self.out_layers = nn.Sequential(
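
For reference, the two hunks above are what keep the "scale_shift" path consistent: the time-embedding projection must be twice as wide as the feature channels so it can later be split into a scale and a shift. A minimal standalone sketch of that conditioning, with illustrative shapes and the standard GroupNorm/FiLM-style application rather than the exact ResnetBlock internals:

import torch
import torch.nn as nn

# illustrative sizes, not taken from any model config
batch, channels, height, width = 2, 8, 4, 4
temb_channels = 16

h = torch.randn(batch, channels, height, width)   # hidden states after the first conv
temb = torch.randn(batch, temb_channels)          # pooled timestep embedding

# "scale_shift" needs a projection twice as wide as the feature channels
temb_proj = nn.Linear(temb_channels, 2 * channels)
norm2 = nn.GroupNorm(num_groups=4, num_channels=channels, eps=1e-5)

emb = temb_proj(nn.SiLU()(temb))[:, :, None, None]   # (B, 2*C, 1, 1)
scale, shift = torch.chunk(emb, 2, dim=1)            # two (B, C, 1, 1) tensors
h = norm2(h) * (1 + scale) + shift                   # FiLM-style modulation
print(h.shape)                                       # torch.Size([2, 8, 4, 4])
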
@@ -551,9 +548,6 @@ class ResnetBlock(nn.Module):
             self.set_weights_ldm()
             self.is_overwritten = True

-        if self.up or self.down:
-            x = self.x_upd(x)
-
         h = x
         h = h * mask
         if self.pre_norm:
@@ -561,6 +555,7 @@ class ResnetBlock(nn.Module):
             h = self.nonlinearity(h)

         if self.up or self.down:
+            x = self.x_upd(x)
             h = self.h_upd(h)

         h = self.conv1(h)
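
The previous two hunks move the skip-path resampling next to the hidden-path resampling, so that when up or down is set both tensors change resolution at the same point and the final residual sum keeps matching shapes. A toy illustration, using plain nn.Upsample rather than the diffusers resampling layers:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 16, 16)   # skip path
h = x                           # hidden path

x_upd = nn.Upsample(scale_factor=2, mode="nearest")
h_upd = nn.Upsample(scale_factor=2, mode="nearest")
conv1 = nn.Conv2d(8, 8, kernel_size=3, padding=1)

# resample both paths at the same point, as the fixed forward() does
x = x_upd(x)
h = h_upd(h)
h = conv1(h)

print((x + h).shape)  # torch.Size([1, 8, 32, 32]); the residual add stays valid
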
@@ -571,7 +566,6 @@ class ResnetBlock(nn.Module):
         h = h * mask

         temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-
         if self.time_embedding_norm == "scale_shift":
             scale, shift = torch.chunk(temb, 2, dim=1)
@@ -595,10 +589,10 @@ class ResnetBlock(nn.Module):
         x = x * mask

         if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
-            else:
-                x = self.nin_shortcut(x)
+            # if self.use_conv_shortcut:
+            #     x = self.conv_shortcut(x)
+            # else:
+            x = self.nin_shortcut(x)

         return x + h
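
Both shortcut hunks reduce the skip connection to the 1x1 "network-in-network" projection, whose only job is to match channel counts so the residual sum is defined. Roughly, and with illustrative names:

import torch
import torch.nn as nn

in_channels, out_channels = 8, 16
x = torch.randn(1, in_channels, 32, 32)
h = torch.randn(1, out_channels, 32, 32)   # output of the block's conv stack

nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

if in_channels != out_channels:
    x = nin_shortcut(x)

print((x + h).shape)  # torch.Size([1, 16, 32, 32])
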

View File

@@ -7,6 +7,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import ResnetBlock


 def convert_module_to_f16(l):
@@ -101,7 +102,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, encoder_out=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, AttentionBlock):
                 x = layer(x, encoder_out)
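
The change above only widens the dispatch so that the new ResnetBlock also receives the timestep embedding. The routing pattern, with simplified stand-ins for the diffusers classes rather than the real API, looks roughly like this:

import torch
import torch.nn as nn

class TimestepAware(nn.Module):
    # stand-in for TimestepBlock / ResnetBlock: takes (x, emb)
    def forward(self, x, emb):
        return x + emb.mean()

class TextAttention(nn.Module):
    # stand-in for AttentionBlock: takes (x, encoder_out)
    def forward(self, x, encoder_out):
        return x if encoder_out is None else x + encoder_out.mean()

def sequential_forward(layers, x, emb, encoder_out=None):
    for layer in layers:
        if isinstance(layer, TimestepAware):
            x = layer(x, emb)          # timestep-conditioned blocks get the embedding
        elif isinstance(layer, TextAttention):
            x = layer(x, encoder_out)  # attention layers get the text encoder states
        else:
            x = layer(x)               # everything else (resampling, convs)
    return x

x, emb = torch.randn(1, 4, 8, 8), torch.randn(1, 32)
out = sequential_forward([TimestepAware(), TextAttention(), nn.Identity()], x, emb)
print(out.shape)  # torch.Size([1, 4, 8, 8])
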
@@ -190,14 +191,24 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(mult * model_channels),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                    # ResBlock(
+                    #     ch,
+                    #     time_embed_dim,
+                    #     dropout,
+                    #     out_channels=int(mult * model_channels),
+                    #     dims=dims,
+                    #     use_checkpoint=use_checkpoint,
+                    #     use_scale_shift_norm=use_scale_shift_norm,
+                    # )
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=mult * model_channels,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift",
+                        overwrite_for_glide=True,
                     )
                 ]
                 ch = int(mult * model_channels)
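
The surrounding loop stacks num_res_blocks of these blocks per resolution level, widening to mult * model_channels at each level. A quick sketch of the resulting channel schedule, with illustrative values rather than the GLIDE defaults:

model_channels = 192          # illustrative, not the checkpoint's config
channel_mult = (1, 2, 3, 4)
num_res_blocks = 2

ch = model_channels
for level, mult in enumerate(channel_mult):
    for _ in range(num_res_blocks):
        out_ch = int(mult * model_channels)
        print(f"level {level}: ResnetBlock {ch} -> {out_ch} channels")
        ch = out_ch
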
@@ -218,15 +229,26 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     down=True,
+                        # )
+                        ResnetBlock(
+                            in_channels=ch,
                             out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift",
+                            overwrite_for_glide=True,
+                            down=True
                         )
                         if resblock_updown
                         else Downsample(
@@ -240,13 +262,22 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 self._feature_size += ch

         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            # ResBlock(
+            #     ch,
+            #     time_embed_dim,
+            #     dropout,
+            #     dims=dims,
+            #     use_checkpoint=use_checkpoint,
+            #     use_scale_shift_norm=use_scale_shift_norm,
+            # ),
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift",
+                overwrite_for_glide=True,
             ),
             AttentionBlock(
                 ch,
@@ -255,14 +286,23 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 num_head_channels=num_head_channels,
                 encoder_channels=transformer_dim,
             ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
+            # ResBlock(
+            #     ch,
+            #     time_embed_dim,
+            #     dropout,
+            #     dims=dims,
+            #     use_checkpoint=use_checkpoint,
+            #     use_scale_shift_norm=use_scale_shift_norm,
+            # ),
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift",
+                overwrite_for_glide=True,
+            )
         )
         self._feature_size += ch
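
After these two hunks the middle block is ResnetBlock -> AttentionBlock -> ResnetBlock at the constant bottleneck width ch. A schematic of that composition, with plain PyTorch stand-ins for the diffusers modules:

import torch
import torch.nn as nn

class MiddleBlockSketch(nn.Module):
    # schematic only: convs stand in for ResnetBlock, MultiheadAttention for AttentionBlock
    def __init__(self, ch):
        super().__init__()
        self.res1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.attn = nn.MultiheadAttention(ch, num_heads=4, batch_first=True)
        self.res2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)

    def forward(self, x):
        h = self.res1(x)
        b, c, height, width = h.shape
        seq = h.flatten(2).transpose(1, 2)          # (B, H*W, C) tokens for attention
        seq, _ = self.attn(seq, seq, seq)
        h = seq.transpose(1, 2).reshape(b, c, height, width)
        return self.res2(h)

x = torch.randn(1, 16, 8, 8)
print(MiddleBlockSketch(16)(x).shape)  # torch.Size([1, 16, 8, 8])
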
@@ -271,15 +311,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(model_channels * mult),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    # ResBlock(
+                    #     ch + ich,
+                    #     time_embed_dim,
+                    #     dropout,
+                    #     out_channels=int(model_channels * mult),
+                    #     dims=dims,
+                    #     use_checkpoint=use_checkpoint,
+                    #     use_scale_shift_norm=use_scale_shift_norm,
+                    # )
+                    ResnetBlock(
+                        in_channels=ch + ich,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift",
+                        overwrite_for_glide=True,
+                    ),
                 ]
                 ch = int(model_channels * mult)
                 if ds in attention_resolutions:
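
In the decoder half the block operates on ch + ich channels because each up block first concatenates the matching skip tensor whose width was recorded in input_block_chans. A minimal sketch of that bookkeeping, with illustrative widths:

# widths recorded while building the encoder ("input blocks"); illustrative values
input_block_chans = [192, 192, 192, 384, 384, 768, 768]

ch = input_block_chans[-1]            # bottleneck width entering the decoder
while input_block_chans:
    ich = input_block_chans.pop()     # matching skip connection from the encoder
    print(f"up block sees cat([h, skip]) with {ch} + {ich} channels")
    # in the real model, ch is then reset to the block's out_channels
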
@@ -295,14 +345,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     up=True,
+                        # )
+                        ResnetBlock(
+                            in_channels=ch,
                             out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift",
+                            overwrite_for_glide=True,
                             up=True,
                         )
                         if resblock_updown