
save intermediate (#87)

* save intermediate

* up

* up
Patrick von Platen
2022-07-14 12:29:06 +02:00
committed by GitHub
parent c3d78cd306
commit e7fe901e5e
10 changed files with 800 additions and 172 deletions

conversion.py (new executable file, 94 lines)
View File

@@ -0,0 +1,94 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest
import numpy as np
import torch
from diffusers import (
AutoencoderKL,
DDIMPipeline,
DDIMScheduler,
DDPMPipeline,
DDPMScheduler,
GlidePipeline,
GlideSuperResUNetModel,
GlideTextToImageUNetModel,
LatentDiffusionPipeline,
LatentDiffusionUncondPipeline,
NCSNpp,
PNDMPipeline,
PNDMScheduler,
ScoreSdeVePipeline,
ScoreSdeVeScheduler,
ScoreSdeVpPipeline,
ScoreSdeVpScheduler,
UNetLDMModel,
UNetModel,
UNetUnconditionalModel,
VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel
def test_output_pretrained_ldm_dummy():
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step)
print(model)
import ipdb; ipdb.set_trace()
def test_output_pretrained_ldm():
model = UNetUnconditionalModel.from_pretrained("fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True)
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step)
print(model)
import ipdb; ipdb.set_trace()
# To see how the final model should look
test_output_pretrained_ldm_dummy()
test_output_pretrained_ldm()
# => this is the architecture in which the model should be saved in the new format
# -> verify new repo with the following tests (in `test_modeling_utils.py`)
# - test_ldm_uncond (in PipelineTesterMixin)
# - test_output_pretrained ( in UNetLDMModelTests)
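The two calls above only load the legacy checkpoints and drop into the debugger so the target architecture can be inspected. A minimal sketch of how the conversion could then be finished, assuming the verified model is simply re-saved with `save_pretrained` (the output path is hypothetical):

```python
import torch

from diffusers import UNetUnconditionalModel

# Load with the legacy flag; the first forward pass copies the old weights
# into the refactored modules via set_weights().
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
model.eval()

noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
timesteps = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
    model(noise, timesteps)

# Hypothetical output path: re-save the model in the new format.
model.save_pretrained("./unet-ldm-dummy-converted")
```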

View File

@@ -111,7 +111,7 @@ prompt = "A painting of a squirrel eating a burger"
image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = image_processed * 255.
image_processed = image_processed * 255.0
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
@@ -143,6 +143,7 @@ audio = bddm(mel_spec, generator, torch_device=torch_device)
# save generated audio
from scipy.io.wavfile import write as wavwrite
sampling_rate = 22050
wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
```

View File

@@ -116,6 +116,7 @@ class ConfigMixin:
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = {"file_type": "config"}
@@ -150,6 +151,7 @@ class ConfigMixin:
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
)
except RepositoryNotFoundError:

View File

@@ -321,6 +321,7 @@ class ModelMixin(torch.nn.Module):
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", None)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -336,6 +337,7 @@ class ModelMixin(torch.nn.Module):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
**kwargs,
)
model.register_to_config(name_or_path=pretrained_model_name_or_path)
@@ -363,6 +365,7 @@ class ModelMixin(torch.nn.Module):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
)
except RepositoryNotFoundError:
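Both `ConfigMixin` and `ModelMixin` now forward a `subfolder` argument to the download call, so a sub-model stored in a nested directory of a Hub repository can be loaded directly. A short usage sketch, mirroring the call already used in `conversion.py`:

```python
from diffusers import UNetUnconditionalModel

# `subfolder` points from_pretrained at the "unet" directory inside the repo
# instead of the repo root; `ldm=True` selects the legacy LDM layout.
unet = UNetUnconditionalModel.from_pretrained(
    "fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True
)
```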

View File

@@ -51,6 +51,7 @@ class AttentionBlock(nn.Module):
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
@@ -62,7 +63,7 @@ class AttentionBlock(nn.Module):
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor
@@ -165,7 +166,7 @@ class AttentionBlock(nn.Module):
return result
class AttentionBlockNew(nn.Module):
class AttentionBlockNew_2(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
@@ -180,11 +181,14 @@ class AttentionBlockNew(nn.Module):
num_groups=32,
encoder_channels=None,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
self.channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = channels // num_head_channels
self.num_head_size = num_head_channels
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
@@ -192,6 +196,28 @@ class AttentionBlockNew(nn.Module):
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
# ------------------------- new -----------------------
num_heads = self.n_heads
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
# ------------------------- new -----------------------
def set_weight(self, attn_layer):
self.norm.weight.data = attn_layer.norm.weight.data
self.norm.bias.data = attn_layer.norm.bias.data
@@ -202,6 +228,89 @@ class AttentionBlockNew(nn.Module):
self.proj.weight.data = attn_layer.proj.weight.data
self.proj.bias.data = attn_layer.proj.bias.data
if hasattr(attn_layer, "q"):
module = attn_layer
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
self.proj = proj_out
self.set_weights_2(attn_layer)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def set_weights_2(self, attn_layer):
self.group_norm.weight.data = attn_layer.norm.weight.data
self.group_norm.bias.data = attn_layer.norm.bias.data
qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels)
qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)
q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)
self.query.weight.data = q_w.reshape(-1, self.channels)
self.key.weight.data = k_w.reshape(-1, self.channels)
self.value.weight.data = v_w.reshape(-1, self.channels)
self.query.bias.data = q_b.reshape(-1)
self.key.bias.data = k_b.reshape(-1)
self.value.bias.data = v_b.reshape(-1)
self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
self.proj_attn.bias.data = attn_layer.proj.bias.data
def forward_2(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
# transpose
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
# get scores
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# compute attention output
context_states = torch.matmul(attention_probs, value_states)
context_states = context_states.permute(0, 2, 1, 3).contiguous()
new_context_states_shape = context_states.size()[:-2] + (self.channels,)
context_states = context_states.view(new_context_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(context_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward(self, x, encoder_out=None):
b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)
@@ -230,10 +339,119 @@ class AttentionBlockNew(nn.Module):
h = h.reshape(b, c, *spatial)
result = x + h
result = result / self.rescale_output_factor
return result
result_2 = self.forward_2(x)
print((result - result_2).abs().sum())
return result_2
class AttentionBlockNew(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.num_head_size = num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
# transpose
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
# get scores
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# compute attention output
context_states = torch.matmul(attention_probs, value_states)
context_states = context_states.permute(0, 2, 1, 3).contiguous()
new_context_states_shape = context_states.size()[:-2] + (self.channels,)
context_states = context_states.view(new_context_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(context_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def set_weight(self, attn_layer):
self.group_norm.weight.data = attn_layer.norm.weight.data
self.group_norm.bias.data = attn_layer.norm.bias.data
qkv_weight = attn_layer.qkv.weight.data.reshape(
self.num_heads, 3 * self.channels // self.num_heads, self.channels
)
qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
self.query.weight.data = q_w.reshape(-1, self.channels)
self.key.weight.data = k_w.reshape(-1, self.channels)
self.value.weight.data = v_w.reshape(-1, self.channels)
self.query.bias.data = q_b.reshape(-1)
self.key.bias.data = k_b.reshape(-1)
self.value.bias.data = v_b.reshape(-1)
self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
self.proj_attn.bias.data = attn_layer.proj.bias.data
class SpatialTransformer(nn.Module):
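The rewritten `AttentionBlockNew` replaces the fused `Conv1d` qkv projection of the legacy block with three `nn.Linear` layers, and `set_weight` splits the old fused weight per attention head. A standalone sketch of that split, with illustrative sizes:

```python
import torch

channels, num_heads = 64, 4

# Legacy layout: a single Conv1d weight with q, k and v stacked along dim 0.
qkv_weight = torch.randn(3 * channels, channels, 1)
qkv_bias = torch.randn(3 * channels)

# The fused tensor is laid out per head, so reshape to (heads, 3 * C / heads, C)
# before splitting into q, k, v chunks of C / heads rows each.
w = qkv_weight[:, :, 0].reshape(num_heads, 3 * channels // num_heads, channels)
b = qkv_bias.reshape(num_heads, 3 * channels // num_heads)
q_w, k_w, v_w = w.split(channels // num_heads, dim=1)
q_b, k_b, v_b = b.split(channels // num_heads, dim=1)

# Flatten back to (C, C) and (C,) so the chunks can be copied into nn.Linear layers.
query_weight, query_bias = q_w.reshape(-1, channels), q_b.reshape(-1)
print(query_weight.shape, query_bias.shape)  # torch.Size([64, 64]) torch.Size([64])
```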

View File

@@ -81,8 +81,10 @@ class Downsample2D(nn.Module):
self.conv = conv
elif name == "Conv2d_0":
self.Conv2d_0 = conv
self.conv = conv
else:
self.op = conv
self.conv = conv
def forward(self, x):
assert x.shape[1] == self.channels
@@ -90,13 +92,16 @@ class Downsample2D(nn.Module):
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
return self.conv(x)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.name == "conv":
return self.conv(x)
elif self.name == "Conv2d_0":
return self.Conv2d_0(x)
else:
return self.op(x)
# if self.name == "conv":
# return self.conv(x)
# elif self.name == "Conv2d_0":
# return self.Conv2d_0(x)
# else:
# return self.op(x)
class Upsample1D(nn.Module):
@@ -656,9 +661,9 @@ class ResnetBlock(nn.Module):
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default" and temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift" and temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
@@ -691,9 +696,9 @@ class ResnetBlock(nn.Module):
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.nin_shortcut = None
self.conv_shortcut = None
if self.use_nin_shortcut:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
@@ -715,7 +720,7 @@ class ResnetBlock(nn.Module):
h = self.nonlinearity(h)
if temb is not None:
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
else:
temb = 0
@@ -738,8 +743,8 @@ class ResnetBlock(nn.Module):
h = self.norm2(h)
h = self.nonlinearity(h)
if self.nin_shortcut is not None:
x = self.nin_shortcut(x)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
return (x + h) / self.output_scale_factor
@@ -750,8 +755,8 @@ class ResnetBlock(nn.Module):
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data
self.temb_proj.weight.data = resnet.temb_proj.weight.data
self.temb_proj.bias.data = resnet.temb_proj.bias.data
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
@@ -760,8 +765,8 @@ class ResnetBlock(nn.Module):
self.conv2.bias.data = resnet.conv2.bias.data
if self.use_nin_shortcut:
self.nin_shortcut.weight.data = resnet.nin_shortcut.weight.data
self.nin_shortcut.bias.data = resnet.nin_shortcut.bias.data
self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
# TODO(Patrick) - just there to convert the weights; can delete afterward
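Besides the attribute renames (`temb_proj` → `time_emb_proj`, `nin_shortcut` → `conv_shortcut`), the block keeps the usual pattern of projecting the time embedding to the block's channel count and broadcasting it over the spatial dimensions. A minimal illustration of that broadcast, with made-up sizes:

```python
import torch
import torch.nn as nn

temb_channels, out_channels = 512, 128
time_emb_proj = nn.Linear(temb_channels, out_channels)

h = torch.randn(2, out_channels, 32, 32)   # feature map after conv1
temb = torch.randn(2, temb_channels)       # per-sample timestep embedding

# Project, then add two trailing singleton dims so the add broadcasts over H and W.
h = h + time_emb_proj(nn.functional.silu(temb))[:, :, None, None]
print(h.shape)  # torch.Size([2, 128, 32, 32])
```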

View File

@@ -177,7 +177,9 @@ class UNetModel(ModelMixin, ConfigMixin):
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
print("hs", hs[-1].abs().sum())
h = self.mid_new(hs[-1], temb)
print("h", h.abs().sum())
# upsampling
for i_level in reversed(range(self.num_resolutions)):
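The `abs().sum()` prints are a quick checksum for comparing activations between the old and the refactored mid block while the weights are being ported. The same parity check written as a small helper (a sketch, not part of the commit):

```python
import torch

def assert_activations_match(old: torch.Tensor, new: torch.Tensor, atol: float = 1e-4) -> None:
    """Compare an intermediate activation of the old block against the refactored one."""
    max_diff = (old - new).abs().max().item()
    assert max_diff < atol, f"refactored block diverges: max abs diff {max_diff:.2e}"
```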

View File

@@ -29,9 +29,10 @@ def get_down_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
downsample_padding=None,
):
if down_block_type == "UNetResDownBlock2D":
return UNetResAttnDownBlock2D(
return UNetResDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -39,6 +40,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
)
elif down_block_type == "UNetResAttnDownBlock2D":
return UNetResAttnDownBlock2D(
@@ -57,7 +59,8 @@ def get_up_block(
up_block_type,
num_layers,
in_channels,
next_channels,
out_channels,
prev_output_channel,
temb_channels,
add_upsample,
resnet_eps,
@@ -68,7 +71,8 @@ def get_up_block(
return UNetResUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
next_channels=next_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -78,7 +82,8 @@ def get_up_block(
return UNetResAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
next_channels=next_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -100,11 +105,14 @@ class UNetMidBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
**kwargs,
):
super().__init__()
self.attention_type = attention_type
# there is always at least one resnet
resnets = [
ResnetBlock(
@@ -128,6 +136,7 @@ class UNetMidBlock2D(nn.Module):
in_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
)
)
resnets.append(
@@ -148,18 +157,15 @@ class UNetMidBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_states=None, mask=None):
if mask is not None:
hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
else:
hidden_states = self.resnets[0](hidden_states, temb)
def forward(self, hidden_states, temb=None, encoder_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_states)
if mask is not None:
hidden_states = resnet(hidden_states, temb, mask=mask)
if self.attention_type == "default":
hidden_states = attn(hidden_states)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -178,6 +184,7 @@ class UNetResAttnDownBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
add_downsample=True,
):
@@ -185,6 +192,8 @@ class UNetResAttnDownBlock2D(nn.Module):
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
@@ -206,6 +215,7 @@ class UNetResAttnDownBlock2D(nn.Module):
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
)
)
@@ -251,6 +261,7 @@ class UNetResDownBlock2D(nn.Module):
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
):
super().__init__()
resnets = []
@@ -276,7 +287,11 @@ class UNetResDownBlock2D(nn.Module):
if add_downsample:
self.downsamplers = nn.ModuleList(
[Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
@@ -301,7 +316,8 @@ class UNetResAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
next_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
@@ -310,7 +326,7 @@ class UNetResAttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attention_layer_type: str = "self",
attention_type="default",
attn_num_head_channels=1,
output_scale_factor=1.0,
add_upsample=True,
@@ -319,12 +335,16 @@ class UNetResAttnUpBlock2D(nn.Module):
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
resnet_channels = in_channels if i < num_layers - 1 else next_channels
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=in_channels + resnet_channels,
out_channels=in_channels,
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
@@ -337,9 +357,10 @@ class UNetResAttnUpBlock2D(nn.Module):
)
attentions.append(
AttentionBlockNew(
in_channels,
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
)
)
@@ -347,7 +368,7 @@ class UNetResAttnUpBlock2D(nn.Module):
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
@@ -373,7 +394,8 @@ class UNetResUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
next_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
@@ -382,7 +404,6 @@ class UNetResUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attention_layer_type: str = "self",
output_scale_factor=1.0,
add_upsample=True,
):
@@ -390,11 +411,13 @@ class UNetResUpBlock2D(nn.Module):
resnets = []
for i in range(num_layers):
resnet_channels = in_channels if i < num_layers - 1 else next_channels
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=in_channels + resnet_channels,
out_channels=in_channels,
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
@@ -409,7 +432,7 @@ class UNetResUpBlock2D(nn.Module):
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
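The up blocks now take `prev_output_channel` and `out_channels` instead of a single `next_channels` value, and each resnet's input width is the sum of its skip connection and the previous layer's output. A worked example of the resulting channel bookkeeping, assuming the default `block_channels=(224, 448, 672, 896)` and three resnets per up block (`num_res_blocks=2`):

```python
block_channels = (224, 448, 672, 896)
num_layers = 3  # num_res_blocks + 1

reversed_channels = list(reversed(block_channels))   # [896, 672, 448, 224]
output_channel = reversed_channels[0]
for i in range(len(reversed_channels)):
    prev_output_channel = output_channel
    output_channel = reversed_channels[i]
    in_channels = reversed_channels[min(i + 1, len(block_channels) - 1)]
    for j in range(num_layers):
        res_skip_channels = in_channels if j == num_layers - 1 else output_channel
        resnet_in_channels = prev_output_channel if j == 0 else output_channel
        print(f"up block {i}, resnet {j}: "
              f"{resnet_in_channels} + {res_skip_channels} skip -> {output_channel}")
```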

View File

@@ -9,6 +9,30 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D, get_down_block, get_up_block
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class TimestepEmbedding(nn.Module):
def __init__(self, channel, time_embed_dim):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class UNetUnconditionalModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
@@ -35,35 +59,66 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
def __init__(
self,
image_size,
in_channels,
out_channels,
num_res_blocks,
image_size=None,
in_channels=None,
out_channels=None,
num_res_blocks=None,
dropout=0,
block_input_channels=(224, 224, 448, 672),
block_output_channels=(224, 448, 672, 896),
block_channels=(224, 448, 672, 896),
down_blocks=(
"UNetResDownBlock2D",
"UNetResAttnDownBlock2D",
"UNetResAttnDownBlock2D",
"UNetResAttnDownBlock2D",
),
downsample_padding=1,
up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
resnet_act_fn="silu",
resnet_eps=1e-5,
conv_resample=True,
num_head_channels=32,
flip_sin_to_cos=True,
downscale_freq_shift=0,
# To delete once weights are converted
# LDM
attention_resolutions=(8, 4, 2),
ldm=False,
# DDPM
out_ch=None,
resolution=None,
attn_resolutions=None,
resamp_with_conv=None,
ch_mult=None,
ch=None,
ddpm=False,
):
super().__init__()
# DELETE if statements if not necessary anymore
# DDPM
if ddpm:
out_channels = out_ch
image_size = resolution
block_channels = [x * ch for x in ch_mult]
conv_resample = resamp_with_conv
flip_sin_to_cos = False
downscale_freq_shift = 1
resnet_eps = 1e-6
block_channels = (32, 64)
down_blocks = (
"UNetResDownBlock2D",
"UNetResAttnDownBlock2D",
)
up_blocks = ("UNetResUpBlock2D", "UNetResAttnUpBlock2D")
downsample_padding = 0
num_head_channels = 64
# register all __init__ params with self.register
self.register_to_config(
image_size=image_size,
in_channels=in_channels,
block_input_channels=block_input_channels,
block_output_channels=block_output_channels,
block_channels=block_channels,
downsample_padding=downsample_padding,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
down_blocks=down_blocks,
@@ -71,37 +126,34 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
dropout=dropout,
conv_resample=conv_resample,
num_head_channels=num_head_channels,
flip_sin_to_cos=flip_sin_to_cos,
downscale_freq_shift=downscale_freq_shift,
# (TODO(PVP) - To delete once weights are converted
attention_resolutions=attention_resolutions,
ldm=ldm,
ddpm=ddpm,
)
# To delete - replace with config values
self.image_size = image_size
self.in_channels = in_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.dropout = dropout
time_embed_dim = block_channels[0] * 4
time_embed_dim = block_input_channels[0] * 4
# # input
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
# ======================== Input ===================
self.conv_in = nn.Conv2d(in_channels, block_input_channels[0], kernel_size=3, padding=(1, 1))
# ======================== Time ====================
self.time_embed = nn.Sequential(
nn.Linear(block_input_channels[0], time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
# ======================== Down ====================
input_channels = list(block_input_channels)
output_channels = list(block_output_channels)
# # time
self.time_embedding = TimestepEmbedding(block_channels[0], time_embed_dim)
self.downsample_blocks = nn.ModuleList([])
for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
down_block_type = down_blocks[i]
is_final_block = i == len(input_channels) - 1
self.mid = None
self.upsample_blocks = nn.ModuleList([])
# down
output_channel = block_channels[0]
for i, down_block_type in enumerate(down_blocks):
input_channel = output_channel
output_channel = block_channels[i]
is_final_block = i == len(block_channels) - 1
down_block = get_down_block(
down_block_type,
@@ -113,30 +165,48 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
downsample_padding=downsample_padding,
)
self.downsample_blocks.append(down_block)
# ======================== Mid ====================
self.mid = UNetMidBlock2D(
in_channels=output_channels[-1],
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
)
# mid
if self.config.ddpm:
self.mid_new_2 = UNetMidBlock2D(
in_channels=block_channels[-1],
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
)
else:
self.mid = UNetMidBlock2D(
in_channels=block_channels[-1],
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
)
self.upsample_blocks = nn.ModuleList([])
for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
up_block_type = up_blocks[i]
is_final_block = i == len(input_channels) - 1
# up
reversed_block_channels = list(reversed(block_channels))
output_channel = reversed_block_channels[0]
for i, up_block_type in enumerate(up_blocks):
prev_output_channel = output_channel
output_channel = reversed_block_channels[i]
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
is_final_block = i == len(block_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=num_res_blocks + 1,
in_channels=output_channel,
next_channels=input_channel,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=resnet_eps,
@@ -144,50 +214,72 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
attn_num_head_channels=num_head_channels,
)
self.upsample_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=32, eps=1e-5)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
# ======================== Out ====================
self.out = nn.Sequential(
nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
nn.SiLU(),
nn.Conv2d(block_input_channels[0], out_channels, 3, padding=1),
)
# =========== TO DELETE AFTER CONVERSION ==========
transformer_depth = 1
context_dim = None
legacy = True
num_heads = -1
model_channels = block_input_channels[0]
channel_mult = tuple([x // model_channels for x in block_output_channels])
self.init_for_ldm(
in_channels,
model_channels,
channel_mult,
num_res_blocks,
dropout,
time_embed_dim,
attention_resolutions,
num_head_channels,
num_heads,
legacy,
False,
transformer_depth,
context_dim,
conv_resample,
out_channels,
)
self.is_overwritten = False
if ldm:
# =========== TO DELETE AFTER CONVERSION ==========
transformer_depth = 1
context_dim = None
legacy = True
num_heads = -1
model_channels = block_channels[0]
channel_mult = tuple([x // model_channels for x in block_channels])
self.init_for_ldm(
in_channels,
model_channels,
channel_mult,
num_res_blocks,
dropout,
time_embed_dim,
attention_resolutions,
num_head_channels,
num_heads,
legacy,
False,
transformer_depth,
context_dim,
conv_resample,
out_channels,
)
if ddpm:
self.init_for_ddpm(
ch_mult,
ch,
num_res_blocks,
resolution,
in_channels,
resamp_with_conv,
attn_resolutions,
out_ch,
dropout=0.1,
)
def forward(self, sample, timesteps=None):
# TODO(PVP) - to delete later
if not self.is_overwritten:
self.set_weights()
# 1. time step embeddings
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
t_emb = get_timestep_embedding(
timesteps, self.config.block_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
timesteps,
self.config.block_channels[0],
flip_sin_to_cos=self.config.flip_sin_to_cos,
downscale_freq_shift=self.config.downscale_freq_shift,
)
emb = self.time_embed(t_emb)
emb = self.time_embedding(t_emb)
# 2. pre-process sample
# sample = sample.type(self.dtype_)
sample = self.conv_in(sample)
# 3. down blocks
@@ -198,8 +290,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# append to tuple
down_block_res_samples += res_samples
print("sample", sample.abs().sum())
# 4. mid block
sample = self.mid(sample, emb)
if self.config.ddpm:
sample = self.mid_new_2(sample, emb)
else:
sample = self.mid(sample, emb)
print("sample", sample.abs().sum())
# 5. up blocks
for upsample_block in self.upsample_blocks:
@@ -211,10 +308,192 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
sample = upsample_block(sample, res_samples, emb)
# 6. post-process sample
sample = self.out(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
def set_weights(self):
self.is_overwritten = True
if self.config.ldm:
self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
for i, input_layer in enumerate(self.input_blocks[1:]):
block_id = i // (self.config.num_res_blocks + 1)
layer_in_block_id = i % (self.config.num_res_blocks + 1)
if layer_in_block_id == 2:
self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
elif len(input_layer) > 1:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.mid.resnets[0].set_weight(self.middle_block[0])
self.mid.resnets[1].set_weight(self.middle_block[2])
self.mid.attentions[0].set_weight(self.middle_block[1])
for i, input_layer in enumerate(self.output_blocks):
block_id = i // (self.config.num_res_blocks + 1)
layer_in_block_id = i % (self.config.num_res_blocks + 1)
if len(input_layer) > 2:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
elif len(input_layer) > 1:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
self.conv_norm_out.weight.data = self.out[0].weight.data
self.conv_norm_out.bias.data = self.out[0].bias.data
self.conv_out.weight.data = self.out[2].weight.data
self.conv_out.bias.data = self.out[2].bias.data
self.remove_ldm()
elif self.config.ddpm:
# =============== SET WEIGHTS ===============
# =============== TIME ======================
self.time_embed[0] = self.temb.dense[0]
self.time_embed[2] = self.temb.dense[1]
for i, block in enumerate(self.down):
if hasattr(block, "downsample"):
self.downsample_blocks[i].downsamplers[0].conv.weight.data = block.downsample.conv.weight.data
self.downsample_blocks[i].downsamplers[0].conv.bias.data = block.downsample.conv.bias.data
if hasattr(block, "block") and len(block.block) > 0:
for j in range(self.num_res_blocks):
self.downsample_blocks[i].resnets[j].set_weight(block.block[j])
if hasattr(block, "attn") and len(block.attn) > 0:
for j in range(self.num_res_blocks):
self.downsample_blocks[i].attentions[j].set_weight(block.attn[j])
self.mid_new_2.resnets[0].set_weight(self.mid.block_1)
self.mid_new_2.resnets[1].set_weight(self.mid.block_2)
self.mid_new_2.attentions[0].set_weight(self.mid.attn_1)
def init_for_ddpm(
self,
ch_mult,
ch,
num_res_blocks,
resolution,
in_channels,
resamp_with_conv,
attn_resolutions,
out_ch,
dropout=0.1,
):
ch_mult = tuple(ch_mult)
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList(
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid_new.resnets[0] = self.mid.block_1
self.mid_new.attentions[0] = self.mid.attn_1
self.mid_new.resnets[1] = self.mid.block_2
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock2D(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def init_for_ldm(
self,
in_channels,
@@ -234,7 +513,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_channels,
):
# TODO(PVP) - delete after weight conversion
class TimestepEmbedSequential(nn.Sequential):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
@@ -255,6 +533,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
dims = 2
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
@@ -389,42 +673,15 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
for i, input_layer in enumerate(self.input_blocks[1:]):
block_id = i // (num_res_blocks + 1)
layer_in_block_id = i % (num_res_blocks + 1)
self.out = nn.Sequential(
nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
nn.SiLU(),
nn.Conv2d(model_channels, out_channels, 3, padding=1),
)
if layer_in_block_id == 2:
self.downsample_blocks[block_id].downsamplers[0].op.weight.data = input_layer[0].op.weight.data
self.downsample_blocks[block_id].downsamplers[0].op.bias.data = input_layer[0].op.bias.data
elif len(input_layer) > 1:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.mid.resnets[0].set_weight(self.middle_block[0])
self.mid.resnets[1].set_weight(self.middle_block[2])
self.mid.attentions[0].set_weight(self.middle_block[1])
for i, input_layer in enumerate(self.output_blocks):
block_id = i // (num_res_blocks + 1)
layer_in_block_id = i % (num_res_blocks + 1)
if len(input_layer) > 2:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
elif len(input_layer) > 1:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
def remove_ldm(self):
del self.time_embed
del self.input_blocks
del self.middle_block
del self.output_blocks
del self.out
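The `ldm` and `ddpm` flags keep the model loadable from the old checkpoints while the refactor is in flight: legacy constructor arguments are remapped onto the new config names before the new blocks are built. A rough sketch of the DDPM remapping with illustrative values (the commit additionally hard-codes the dummy checkpoint's `block_channels=(32, 64)` and related settings):

```python
# Illustrative legacy DDPM kwargs, not taken from a real checkpoint.
legacy = {"ch": 32, "ch_mult": (1, 2), "out_ch": 3, "resolution": 32, "resamp_with_conv": True}

out_channels = legacy["out_ch"]
image_size = legacy["resolution"]
block_channels = [legacy["ch"] * m for m in legacy["ch_mult"]]   # -> [32, 64]
conv_resample = legacy["resamp_with_conv"]
# DDPM checkpoints use the opposite sin/cos ordering, a shifted frequency offset
# and a smaller GroupNorm eps than the new defaults.
flip_sin_to_cos, downscale_freq_shift, resnet_eps = False, 1, 1e-6
print(block_channels, out_channels, image_size)
```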

View File

@@ -271,6 +271,27 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
print("Original success!!!")
model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10])
with torch.no_grad():
output = model(noise, time_step)
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
model_class = GlideSuperResUNetModel
@@ -486,18 +507,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
"out_channels": 4,
"num_res_blocks": 2,
"attention_resolutions": (16,),
"block_input_channels": [32, 32],
"block_output_channels": [32, 64],
"block_channels": (32, 64),
"num_head_channels": 32,
"conv_resample": True,
"down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
"up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
"ldm": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
model, loading_info = UNetUnconditionalModel.from_pretrained(
"fusing/unet-ldm-dummy", output_loading_info=True, ldm=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -507,7 +530,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy")
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
model.eval()
torch.manual_seed(0)