Mirror of https://github.com/huggingface/diffusers.git
commit e7fe901e5e (parent c3d78cd306), committed by GitHub

conversion.py (new executable file, 94 lines)
@@ -0,0 +1,94 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
import tempfile
import unittest

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    GlidePipeline,
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
    LatentDiffusionPipeline,
    LatentDiffusionUncondPipeline,
    NCSNpp,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    ScoreSdeVpPipeline,
    ScoreSdeVpScheduler,
    UNetLDMModel,
    UNetModel,
    UNetUnconditionalModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel


def test_output_pretrained_ldm_dummy():
    model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()


def test_output_pretrained_ldm():
    model = UNetUnconditionalModel.from_pretrained("fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()

# To see how the final model should look
test_output_pretrained_ldm_dummy()
test_output_pretrained_ldm()
# => this is the architecture in which the model should be saved in the new format
# -> verify the new repo with the following tests (in `test_modeling_utils.py`)
# - test_ldm_uncond (in PipelineTesterMixin)
# - test_output_pretrained (in UNetLDMModelTests)
@@ -111,7 +111,7 @@ prompt = "A painting of a squirrel eating a burger"
image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)

image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = image_processed * 255.
image_processed = image_processed * 255.0
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

@@ -143,6 +143,7 @@ audio = bddm(mel_spec, generator, torch_device=torch_device)

# save generated audio
from scipy.io.wavfile import write as wavwrite

sampling_rate = 22050
wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
```

@@ -116,6 +116,7 @@ class ConfigMixin:
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)

        user_agent = {"file_type": "config"}

@@ -150,6 +151,7 @@ class ConfigMixin:
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
            )

        except RepositoryNotFoundError:

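The `subfolder` kwarg added here lets the config live in a subdirectory of a repo, as the `fusing/latent-diffusion-celeba-256` checkpoint used in conversion.py does. A minimal sketch of the path resolution this implies, with a hypothetical helper name and the default `config.json` filename assumed (not the actual implementation):

```python
import os
from typing import Optional


# Hypothetical helper (illustration only): fold an optional subfolder into the
# filename that is requested from the hub.
def resolve_config_file(subfolder: Optional[str], config_name: str = "config.json") -> str:
    return config_name if subfolder is None else os.path.join(subfolder, config_name)


print(resolve_config_file(None))    # config.json
print(resolve_config_file("unet"))  # unet/config.json
```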
@@ -321,6 +321,7 @@ class ModelMixin(torch.nn.Module):
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
|
||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||
|
||||
@@ -336,6 +337,7 @@ class ModelMixin(torch.nn.Module):
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
**kwargs,
|
||||
)
|
||||
model.register_to_config(name_or_path=pretrained_model_name_or_path)
|
||||
@@ -363,6 +365,7 @@ class ModelMixin(torch.nn.Module):
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
|
||||
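On the model side the same kwarg is exercised by conversion.py above; for reference, the call looks like this:

```python
from diffusers import UNetUnconditionalModel

# Load UNet weights that live under the "unet" subfolder of a pipeline repo,
# exactly as conversion.py does.
model = UNetUnconditionalModel.from_pretrained(
    "fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True
)
```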
@@ -51,6 +51,7 @@ class AttentionBlock(nn.Module):
|
||||
overwrite_qkv=False,
|
||||
overwrite_linear=False,
|
||||
rescale_output_factor=1.0,
|
||||
eps=1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
@@ -62,7 +63,7 @@ class AttentionBlock(nn.Module):
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
|
||||
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
|
||||
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
||||
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
||||
self.n_heads = self.num_heads
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
@@ -165,7 +166,7 @@ class AttentionBlock(nn.Module):
|
||||
return result
|
||||
|
||||
|
||||
class AttentionBlockNew(nn.Module):
|
||||
class AttentionBlockNew_2(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
|
||||
@@ -180,11 +181,14 @@ class AttentionBlockNew(nn.Module):
|
||||
num_groups=32,
|
||||
encoder_channels=None,
|
||||
rescale_output_factor=1.0,
|
||||
eps=1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
|
||||
self.channels = channels
|
||||
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
||||
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
||||
self.n_heads = channels // num_head_channels
|
||||
self.num_head_size = num_head_channels
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
|
||||
if encoder_channels is not None:
|
||||
@@ -192,6 +196,28 @@ class AttentionBlockNew(nn.Module):
|
||||
|
||||
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
|
||||
|
||||
# ------------------------- new -----------------------
|
||||
num_heads = self.n_heads
|
||||
self.channels = channels
|
||||
if num_head_channels is None:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
|
||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
||||
|
||||
# define q,k,v as linear layers
|
||||
self.query = nn.Linear(channels, channels)
|
||||
self.key = nn.Linear(channels, channels)
|
||||
self.value = nn.Linear(channels, channels)
|
||||
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
|
||||
# ------------------------- new -----------------------
|
||||
|
||||
def set_weight(self, attn_layer):
|
||||
self.norm.weight.data = attn_layer.norm.weight.data
|
||||
self.norm.bias.data = attn_layer.norm.bias.data
|
||||
@@ -202,6 +228,89 @@ class AttentionBlockNew(nn.Module):
|
||||
self.proj.weight.data = attn_layer.proj.weight.data
|
||||
self.proj.bias.data = attn_layer.proj.bias.data
|
||||
|
||||
if hasattr(attn_layer, "q"):
|
||||
module = attn_layer
|
||||
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
|
||||
:, :, :, 0
|
||||
]
|
||||
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
|
||||
|
||||
self.qkv.weight.data = qkv_weight
|
||||
self.qkv.bias.data = qkv_bias
|
||||
|
||||
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
|
||||
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
|
||||
proj_out.bias.data = module.proj_out.bias.data
|
||||
|
||||
self.proj = proj_out
|
||||
|
||||
self.set_weights_2(attn_layer)
|
||||
|
||||
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
||||
new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
|
||||
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
||||
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
||||
return new_projection
|
||||
|
||||
def set_weights_2(self, attn_layer):
|
||||
self.group_norm.weight.data = attn_layer.norm.weight.data
|
||||
self.group_norm.bias.data = attn_layer.norm.bias.data
|
||||
|
||||
qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels)
|
||||
qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)
|
||||
|
||||
q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
|
||||
q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)
|
||||
|
||||
self.query.weight.data = q_w.reshape(-1, self.channels)
|
||||
self.key.weight.data = k_w.reshape(-1, self.channels)
|
||||
self.value.weight.data = v_w.reshape(-1, self.channels)
|
||||
|
||||
self.query.bias.data = q_b.reshape(-1)
|
||||
self.key.bias.data = k_b.reshape(-1)
|
||||
self.value.bias.data = v_b.reshape(-1)
|
||||
|
||||
self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
|
||||
self.proj_attn.bias.data = attn_layer.proj.bias.data
|
||||
|
||||
def forward_2(self, hidden_states):
|
||||
residual = hidden_states
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
|
||||
# norm
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
||||
|
||||
# proj to q, k, v
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
# transpose
|
||||
query_states = self.transpose_for_scores(query_proj)
|
||||
key_states = self.transpose_for_scores(key_proj)
|
||||
value_states = self.transpose_for_scores(value_proj)
|
||||
|
||||
# get scores
|
||||
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
|
||||
attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# compute attention output
|
||||
context_states = torch.matmul(attention_probs, value_states)
|
||||
|
||||
context_states = context_states.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_states_shape = context_states.size()[:-2] + (self.channels,)
|
||||
context_states = context_states.view(new_context_states_shape)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(context_states)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
||||
|
||||
# res connect and rescale
|
||||
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
||||
return hidden_states
|
||||
|
||||
def forward(self, x, encoder_out=None):
|
||||
b, c, *spatial = x.shape
|
||||
hid_states = self.norm(x).view(b, c, -1)
|
||||
@@ -230,10 +339,119 @@ class AttentionBlockNew(nn.Module):
|
||||
h = h.reshape(b, c, *spatial)
|
||||
|
||||
result = x + h
|
||||
|
||||
result = result / self.rescale_output_factor
|
||||
|
||||
return result
|
||||
result_2 = self.forward_2(x)
|
||||
|
||||
print((result - result_2).abs().sum())
|
||||
|
||||
return result_2
|
||||
|
||||
|
||||
class AttentionBlockNew(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    Uses three q, k, v linear layers to compute attention
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=None,
        num_groups=32,
        rescale_output_factor=1.0,
        eps=1e-5,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels is None:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels

        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        # transpose
        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        # get scores
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # compute attention output
        context_states = torch.matmul(attention_probs, value_states)

        context_states = context_states.permute(0, 2, 1, 3).contiguous()
        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
        context_states = context_states.view(new_context_states_shape)

        # compute next hidden_states
        hidden_states = self.proj_attn(context_states)
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def set_weight(self, attn_layer):
        self.group_norm.weight.data = attn_layer.norm.weight.data
        self.group_norm.bias.data = attn_layer.norm.bias.data

        qkv_weight = attn_layer.qkv.weight.data.reshape(
            self.num_heads, 3 * self.channels // self.num_heads, self.channels
        )
        qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)

        q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
        q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)

        self.query.weight.data = q_w.reshape(-1, self.channels)
        self.key.weight.data = k_w.reshape(-1, self.channels)
        self.value.weight.data = v_w.reshape(-1, self.channels)

        self.query.bias.data = q_b.reshape(-1)
        self.key.bias.data = k_b.reshape(-1)
        self.value.bias.data = v_b.reshape(-1)

        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
        self.proj_attn.bias.data = attn_layer.proj.bias.data


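The `set_weight` port above relies on the fused `qkv` Conv1d storing the per-head blocks as `[q | k | v]`. A self-contained sanity check with toy sizes (not library code) showing that the split reproduces the old per-head queries:

```python
import torch
from torch import nn

# Toy sizes: 2 heads of 4 channels each, sequence length 5.
channels, n_heads, length = 8, 2, 5
head_dim = channels // n_heads

qkv = nn.Conv1d(channels, channels * 3, 1)  # old fused projection
query = nn.Linear(channels, channels)       # new separate projection

# Split the fused weight exactly the way set_weight does.
qkv_w = qkv.weight.data.reshape(n_heads, 3 * head_dim, channels)
qkv_b = qkv.bias.data.reshape(n_heads, 3 * head_dim)
q_w, _, _ = qkv_w.split(head_dim, dim=1)
q_b, _, _ = qkv_b.split(head_dim, dim=1)
query.weight.data = q_w.reshape(-1, channels)
query.bias.data = q_b.reshape(-1)

# Old path: fused conv output reshaped per head; new path: Linear + reshape.
x = torch.randn(1, channels, length)
old_q = qkv(x).reshape(n_heads, 3 * head_dim, length)[:, :head_dim]
new_q = query(x.transpose(1, 2)).view(1, length, n_heads, head_dim)
new_q = new_q.permute(0, 2, 3, 1)[0]  # (heads, head_dim, length)
assert torch.allclose(old_q, new_q, atol=1e-6)
```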
class SpatialTransformer(nn.Module):
|
||||
|
||||
@@ -81,8 +81,10 @@ class Downsample2D(nn.Module):
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.Conv2d_0 = conv
|
||||
self.conv = conv
|
||||
else:
|
||||
self.op = conv
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
@@ -90,13 +92,16 @@ class Downsample2D(nn.Module):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
|
||||
return self.conv(x)
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.name == "conv":
|
||||
return self.conv(x)
|
||||
elif self.name == "Conv2d_0":
|
||||
return self.Conv2d_0(x)
|
||||
else:
|
||||
return self.op(x)
|
||||
|
||||
|
||||
# if self.name == "conv":
|
||||
# return self.conv(x)
|
||||
# elif self.name == "Conv2d_0":
|
||||
# return self.Conv2d_0(x)
|
||||
# else:
|
||||
# return self.op(x)
|
||||
|
||||
|
||||
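Registering the same conv under both the legacy attribute name and `conv` appears to be what keeps old checkpoint keys loadable during the rename (see the TODO above); both key prefixes end up in the state dict. A small illustration with a toy module, assuming that intent:

```python
from torch import nn


class ToyDownsample(nn.Module):
    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(4, 4, 3, stride=2, padding=1)
        # same module registered under the legacy name and the new name
        self.Conv2d_0 = conv
        self.conv = conv


print(sorted(ToyDownsample().state_dict().keys()))
# ['Conv2d_0.bias', 'Conv2d_0.weight', 'conv.bias', 'conv.weight']
```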
class Upsample1D(nn.Module):
|
||||
@@ -656,9 +661,9 @@ class ResnetBlock(nn.Module):
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if time_embedding_norm == "default" and temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
elif time_embedding_norm == "scale_shift" and temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
||||
|
||||
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
@@ -691,9 +696,9 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
|
||||
|
||||
self.nin_shortcut = None
|
||||
self.conv_shortcut = None
|
||||
if self.use_nin_shortcut:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
@@ -715,7 +720,7 @@ class ResnetBlock(nn.Module):
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
if temb is not None:
|
||||
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
else:
|
||||
temb = 0
|
||||
|
||||
@@ -738,8 +743,8 @@ class ResnetBlock(nn.Module):
|
||||
h = self.norm2(h)
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
if self.nin_shortcut is not None:
|
||||
x = self.nin_shortcut(x)
|
||||
if self.conv_shortcut is not None:
|
||||
x = self.conv_shortcut(x)
|
||||
|
||||
return (x + h) / self.output_scale_factor
|
||||
|
||||
@@ -750,8 +755,8 @@ class ResnetBlock(nn.Module):
|
||||
self.conv1.weight.data = resnet.conv1.weight.data
|
||||
self.conv1.bias.data = resnet.conv1.bias.data
|
||||
|
||||
self.temb_proj.weight.data = resnet.temb_proj.weight.data
|
||||
self.temb_proj.bias.data = resnet.temb_proj.bias.data
|
||||
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
|
||||
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
|
||||
|
||||
self.norm2.weight.data = resnet.norm2.weight.data
|
||||
self.norm2.bias.data = resnet.norm2.bias.data
|
||||
@@ -760,8 +765,8 @@ class ResnetBlock(nn.Module):
|
||||
self.conv2.bias.data = resnet.conv2.bias.data
|
||||
|
||||
if self.use_nin_shortcut:
|
||||
self.nin_shortcut.weight.data = resnet.nin_shortcut.weight.data
|
||||
self.nin_shortcut.bias.data = resnet.nin_shortcut.bias.data
|
||||
self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
|
||||
self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
|
||||
|
||||
|
||||
# TODO(Patrick) - just there to convert the weights; can delete afterward
|
||||
|
||||
@@ -177,7 +177,9 @@ class UNetModel(ModelMixin, ConfigMixin):
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
print("hs", hs[-1].abs().sum())
|
||||
h = self.mid_new(hs[-1], temb)
|
||||
print("h", h.abs().sum())
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
|
||||
@@ -29,9 +29,10 @@ def get_down_block(
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
attn_num_head_channels,
|
||||
downsample_padding=None,
|
||||
):
|
||||
if down_block_type == "UNetResDownBlock2D":
|
||||
return UNetResAttnDownBlock2D(
|
||||
return UNetResDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -39,6 +40,7 @@ def get_down_block(
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
elif down_block_type == "UNetResAttnDownBlock2D":
|
||||
return UNetResAttnDownBlock2D(
|
||||
@@ -57,7 +59,8 @@ def get_up_block(
|
||||
up_block_type,
|
||||
num_layers,
|
||||
in_channels,
|
||||
next_channels,
|
||||
out_channels,
|
||||
prev_output_channel,
|
||||
temb_channels,
|
||||
add_upsample,
|
||||
resnet_eps,
|
||||
@@ -68,7 +71,8 @@ def get_up_block(
|
||||
return UNetResUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
next_channels=next_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -78,7 +82,8 @@ def get_up_block(
|
||||
return UNetResAttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
next_channels=next_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -100,11 +105,14 @@ class UNetMidBlock2D(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock(
|
||||
@@ -128,6 +136,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
in_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
@@ -148,18 +157,15 @@ class UNetMidBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_states=None, mask=None):
|
||||
if mask is not None:
|
||||
hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
|
||||
else:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
def forward(self, hidden_states, temb=None, encoder_states=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_states)
|
||||
if mask is not None:
|
||||
hidden_states = resnet(hidden_states, temb, mask=mask)
|
||||
if self.attention_type == "default":
|
||||
hidden_states = attn(hidden_states)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
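The rewritten mid-block forward drops the `mask` path and fixes the call order: the first resnet runs alone, then attention and resnet alternate. A toy trace (stand-in callables, not the real modules):

```python
def make_op(name, trace):
    def op(hidden_states, *args):
        trace.append(name)
        return hidden_states
    return op


trace = []
resnets = [make_op("resnet0", trace), make_op("resnet1", trace)]
attentions = [make_op("attn0", trace)]

hidden_states, temb = object(), object()
hidden_states = resnets[0](hidden_states, temb)
for attn, resnet in zip(attentions, resnets[1:]):
    hidden_states = attn(hidden_states)
    hidden_states = resnet(hidden_states, temb)

print(trace)  # ['resnet0', 'attn0', 'resnet1']
```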
@@ -178,6 +184,7 @@ class UNetResAttnDownBlock2D(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
):
|
||||
@@ -185,6 +192,8 @@ class UNetResAttnDownBlock2D(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
@@ -206,6 +215,7 @@ class UNetResAttnDownBlock2D(nn.Module):
|
||||
out_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -251,6 +261,7 @@ class UNetResDownBlock2D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
downsample_padding=1,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -276,7 +287,11 @@ class UNetResDownBlock2D(nn.Module):
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
|
||||
[
|
||||
Downsample2D(
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
@@ -301,7 +316,8 @@ class UNetResAttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
next_channels: int,
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
@@ -310,7 +326,7 @@ class UNetResAttnUpBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attention_layer_type: str = "self",
|
||||
attention_type="default",
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
@@ -319,12 +335,16 @@ class UNetResAttnUpBlock2D(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
resnet_channels = in_channels if i < num_layers - 1 else next_channels
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_channels + resnet_channels,
|
||||
out_channels=in_channels,
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
@@ -337,9 +357,10 @@ class UNetResAttnUpBlock2D(nn.Module):
|
||||
)
|
||||
attentions.append(
|
||||
AttentionBlockNew(
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
)
|
||||
)
|
||||
|
||||
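The switch to `resnet_in_channels + res_skip_channels` sizes each resnet of an up block for the concatenated skip connection. A hedged walk-through with hypothetical numbers (out_channels=448, the next smaller block at 224, previous block output 672, three layers):

```python
in_channels, out_channels, prev_output_channel, num_layers = 224, 448, 672, 3

for i in range(num_layers):
    res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
    resnet_in_channels = prev_output_channel if i == 0 else out_channels
    print(f"layer {i}: ResnetBlock({resnet_in_channels + res_skip_channels} -> {out_channels})")

# layer 0: ResnetBlock(1120 -> 448)
# layer 1: ResnetBlock(896 -> 448)
# layer 2: ResnetBlock(672 -> 448)
```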
@@ -347,7 +368,7 @@ class UNetResAttnUpBlock2D(nn.Module):
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
|
||||
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
@@ -373,7 +394,8 @@ class UNetResUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
next_channels: int,
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
@@ -382,7 +404,6 @@ class UNetResUpBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attention_layer_type: str = "self",
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
):
|
||||
@@ -390,11 +411,13 @@ class UNetResUpBlock2D(nn.Module):
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
resnet_channels = in_channels if i < num_layers - 1 else next_channels
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_channels + resnet_channels,
|
||||
out_channels=in_channels,
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
@@ -409,7 +432,7 @@ class UNetResUpBlock2D(nn.Module):
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
|
||||
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
|
||||
@@ -9,6 +9,30 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, channel, time_embed_dim):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(channel, time_embed_dim)
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
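`TimestepEmbedding` is the MLP half of the timestep conditioning; the forward pass below feeds it sinusoidal features from `get_timestep_embedding`. A self-contained sketch of the pairing (the sinusoidal helper here is a rough stand-in, not the library function), using the default `block_channels = (224, 448, 672, 896)`:

```python
import math

import torch
from torch import nn


def sinusoidal_embedding(timesteps, dim):
    # rough stand-in for get_timestep_embedding
    half_dim = dim // 2
    exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim
    angles = timesteps[:, None].float() * exponent.exp()[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class TimestepEmbedding(nn.Module):  # copy of the module added above
    def __init__(self, channel, time_embed_dim):
        super().__init__()
        self.linear_1 = nn.Linear(channel, time_embed_dim)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, sample):
        return self.linear_2(self.act(self.linear_1(sample)))


timesteps = torch.tensor([10, 500])
t_emb = sinusoidal_embedding(timesteps, 224)  # (2, 224)
emb = TimestepEmbedding(224, 224 * 4)(t_emb)  # (2, 896)
print(t_emb.shape, emb.shape)
```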
class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
|
||||
@@ -35,35 +59,66 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
image_size=None,
|
||||
in_channels=None,
|
||||
out_channels=None,
|
||||
num_res_blocks=None,
|
||||
dropout=0,
|
||||
block_input_channels=(224, 224, 448, 672),
|
||||
block_output_channels=(224, 448, 672, 896),
|
||||
block_channels=(224, 448, 672, 896),
|
||||
down_blocks=(
|
||||
"UNetResDownBlock2D",
|
||||
"UNetResAttnDownBlock2D",
|
||||
"UNetResAttnDownBlock2D",
|
||||
"UNetResAttnDownBlock2D",
|
||||
),
|
||||
downsample_padding=1,
|
||||
up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
|
||||
resnet_act_fn="silu",
|
||||
resnet_eps=1e-5,
|
||||
conv_resample=True,
|
||||
num_head_channels=32,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0,
|
||||
# To delete once weights are converted
|
||||
# LDM
|
||||
attention_resolutions=(8, 4, 2),
|
||||
ldm=False,
|
||||
# DDPM
|
||||
out_ch=None,
|
||||
resolution=None,
|
||||
attn_resolutions=None,
|
||||
resamp_with_conv=None,
|
||||
ch_mult=None,
|
||||
ch=None,
|
||||
ddpm=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# DELETE if statements if not necessary anymore
|
||||
# DDPM
|
||||
if ddpm:
|
||||
out_channels = out_ch
|
||||
image_size = resolution
|
||||
block_channels = [x * ch for x in ch_mult]
|
||||
conv_resample = resamp_with_conv
|
||||
flip_sin_to_cos = False
|
||||
downscale_freq_shift = 1
|
||||
resnet_eps = 1e-6
|
||||
block_channels = (32, 64)
|
||||
down_blocks = (
|
||||
"UNetResDownBlock2D",
|
||||
"UNetResAttnDownBlock2D",
|
||||
)
|
||||
up_blocks = ("UNetResUpBlock2D", "UNetResAttnUpBlock2D")
|
||||
downsample_padding = 0
|
||||
num_head_channels = 64
|
||||
|
||||
# register all __init__ params with self.register
|
||||
self.register_to_config(
|
||||
image_size=image_size,
|
||||
in_channels=in_channels,
|
||||
block_input_channels=block_input_channels,
|
||||
block_output_channels=block_output_channels,
|
||||
block_channels=block_channels,
|
||||
downsample_padding=downsample_padding,
|
||||
out_channels=out_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
down_blocks=down_blocks,
|
||||
@@ -71,37 +126,34 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
dropout=dropout,
|
||||
conv_resample=conv_resample,
|
||||
num_head_channels=num_head_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
downscale_freq_shift=downscale_freq_shift,
|
||||
# (TODO(PVP) - To delete once weights are converted
|
||||
attention_resolutions=attention_resolutions,
|
||||
ldm=ldm,
|
||||
ddpm=ddpm,
|
||||
)
|
||||
|
||||
# To delete - replace with config values
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.dropout = dropout
|
||||
time_embed_dim = block_channels[0] * 4
|
||||
|
||||
time_embed_dim = block_input_channels[0] * 4
|
||||
# # input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# ======================== Input ===================
|
||||
self.conv_in = nn.Conv2d(in_channels, block_input_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# ======================== Time ====================
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(block_input_channels[0], time_embed_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
# ======================== Down ====================
|
||||
input_channels = list(block_input_channels)
|
||||
output_channels = list(block_output_channels)
|
||||
# # time
|
||||
self.time_embedding = TimestepEmbedding(block_channels[0], time_embed_dim)
|
||||
|
||||
self.downsample_blocks = nn.ModuleList([])
|
||||
for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
|
||||
down_block_type = down_blocks[i]
|
||||
is_final_block = i == len(input_channels) - 1
|
||||
self.mid = None
|
||||
self.upsample_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_channels[0]
|
||||
for i, down_block_type in enumerate(down_blocks):
|
||||
input_channel = output_channel
|
||||
output_channel = block_channels[i]
|
||||
is_final_block = i == len(block_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
@@ -113,30 +165,48 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=num_head_channels,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
self.downsample_blocks.append(down_block)
|
||||
|
||||
# ======================== Mid ====================
|
||||
self.mid = UNetMidBlock2D(
|
||||
in_channels=output_channels[-1],
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_time_scale_shift="default",
|
||||
attn_num_head_channels=num_head_channels,
|
||||
)
|
||||
# mid
|
||||
if self.config.ddpm:
|
||||
self.mid_new_2 = UNetMidBlock2D(
|
||||
in_channels=block_channels[-1],
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_time_scale_shift="default",
|
||||
attn_num_head_channels=num_head_channels,
|
||||
)
|
||||
else:
|
||||
self.mid = UNetMidBlock2D(
|
||||
in_channels=block_channels[-1],
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_time_scale_shift="default",
|
||||
attn_num_head_channels=num_head_channels,
|
||||
)
|
||||
|
||||
self.upsample_blocks = nn.ModuleList([])
|
||||
for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
|
||||
up_block_type = up_blocks[i]
|
||||
is_final_block = i == len(input_channels) - 1
|
||||
# up
|
||||
reversed_block_channels = list(reversed(block_channels))
|
||||
output_channel = reversed_block_channels[0]
|
||||
for i, up_block_type in enumerate(up_blocks):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_channels[i]
|
||||
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=num_res_blocks + 1,
|
||||
in_channels=output_channel,
|
||||
next_channels=input_channel,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=resnet_eps,
|
||||
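The rewritten construction loops derive every block's channel widths from `block_channels` alone. For the default `(224, 448, 672, 896)` the loops produce the following pairings (hedged illustration mirroring the loop bodies above):

```python
block_channels = (224, 448, 672, 896)

# down path: each block maps the previous block's width onto its own
output_channel = block_channels[0]
for i in range(len(block_channels)):
    input_channel, output_channel = output_channel, block_channels[i]
    print("down", i, input_channel, "->", output_channel)
# down: 224->224, 224->448, 448->672, 672->896

# up path: track the previous block's output and the next (smaller) skip width
reversed_block_channels = list(reversed(block_channels))
output_channel = reversed_block_channels[0]
for i in range(len(block_channels)):
    prev_output_channel = output_channel
    output_channel = reversed_block_channels[i]
    input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
    print("up", i, prev_output_channel, output_channel, input_channel)
# up: (896, 896, 672), (896, 672, 448), (672, 448, 224), (448, 224, 224)
```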
@@ -144,50 +214,72 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
attn_num_head_channels=num_head_channels,
|
||||
)
|
||||
self.upsample_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=32, eps=1e-5)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
# ======================== Out ====================
|
||||
self.out = nn.Sequential(
|
||||
nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(block_input_channels[0], out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
# =========== TO DELETE AFTER CONVERSION ==========
|
||||
transformer_depth = 1
|
||||
context_dim = None
|
||||
legacy = True
|
||||
num_heads = -1
|
||||
model_channels = block_input_channels[0]
|
||||
channel_mult = tuple([x // model_channels for x in block_output_channels])
|
||||
self.init_for_ldm(
|
||||
in_channels,
|
||||
model_channels,
|
||||
channel_mult,
|
||||
num_res_blocks,
|
||||
dropout,
|
||||
time_embed_dim,
|
||||
attention_resolutions,
|
||||
num_head_channels,
|
||||
num_heads,
|
||||
legacy,
|
||||
False,
|
||||
transformer_depth,
|
||||
context_dim,
|
||||
conv_resample,
|
||||
out_channels,
|
||||
)
|
||||
self.is_overwritten = False
|
||||
if ldm:
|
||||
# =========== TO DELETE AFTER CONVERSION ==========
|
||||
transformer_depth = 1
|
||||
context_dim = None
|
||||
legacy = True
|
||||
num_heads = -1
|
||||
model_channels = block_channels[0]
|
||||
channel_mult = tuple([x // model_channels for x in block_channels])
|
||||
self.init_for_ldm(
|
||||
in_channels,
|
||||
model_channels,
|
||||
channel_mult,
|
||||
num_res_blocks,
|
||||
dropout,
|
||||
time_embed_dim,
|
||||
attention_resolutions,
|
||||
num_head_channels,
|
||||
num_heads,
|
||||
legacy,
|
||||
False,
|
||||
transformer_depth,
|
||||
context_dim,
|
||||
conv_resample,
|
||||
out_channels,
|
||||
)
|
||||
if ddpm:
|
||||
self.init_for_ddpm(
|
||||
ch_mult,
|
||||
ch,
|
||||
num_res_blocks,
|
||||
resolution,
|
||||
in_channels,
|
||||
resamp_with_conv,
|
||||
attn_resolutions,
|
||||
out_ch,
|
||||
dropout=0.1,
|
||||
)
|
||||
|
||||
def forward(self, sample, timesteps=None):
|
||||
# TODO(PVP) - to delete later
|
||||
if not self.is_overwritten:
|
||||
self.set_weights()
|
||||
|
||||
# 1. time step embeddings
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps, self.config.block_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
timesteps,
|
||||
self.config.block_channels[0],
|
||||
flip_sin_to_cos=self.config.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.config.downscale_freq_shift,
|
||||
)
|
||||
emb = self.time_embed(t_emb)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process sample
|
||||
# sample = sample.type(self.dtype_)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down blocks
|
||||
@@ -198,8 +290,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
# append to tuple
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
print("sample", sample.abs().sum())
|
||||
# 4. mid block
|
||||
sample = self.mid(sample, emb)
|
||||
if self.config.ddpm:
|
||||
sample = self.mid_new_2(sample, emb)
|
||||
else:
|
||||
sample = self.mid(sample, emb)
|
||||
print("sample", sample.abs().sum())
|
||||
|
||||
# 5. up blocks
|
||||
for upsample_block in self.upsample_blocks:
|
||||
@@ -211,10 +308,192 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
# 6. post-process sample
|
||||
sample = self.out(sample)
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def set_weights(self):
|
||||
self.is_overwritten = True
|
||||
if self.config.ldm:
|
||||
|
||||
self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
|
||||
self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
|
||||
self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
|
||||
self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data
|
||||
|
||||
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
|
||||
for i, input_layer in enumerate(self.input_blocks[1:]):
|
||||
block_id = i // (self.config.num_res_blocks + 1)
|
||||
layer_in_block_id = i % (self.config.num_res_blocks + 1)
|
||||
|
||||
if layer_in_block_id == 2:
|
||||
self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
|
||||
self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
|
||||
elif len(input_layer) > 1:
|
||||
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
else:
|
||||
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
|
||||
self.mid.resnets[0].set_weight(self.middle_block[0])
|
||||
self.mid.resnets[1].set_weight(self.middle_block[2])
|
||||
self.mid.attentions[0].set_weight(self.middle_block[1])
|
||||
|
||||
for i, input_layer in enumerate(self.output_blocks):
|
||||
block_id = i // (self.config.num_res_blocks + 1)
|
||||
layer_in_block_id = i % (self.config.num_res_blocks + 1)
|
||||
|
||||
if len(input_layer) > 2:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
|
||||
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
|
||||
elif len(input_layer) > 1:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
else:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
|
||||
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
|
||||
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
|
||||
|
||||
self.conv_norm_out.weight.data = self.out[0].weight.data
|
||||
self.conv_norm_out.bias.data = self.out[0].bias.data
|
||||
self.conv_out.weight.data = self.out[2].weight.data
|
||||
self.conv_out.bias.data = self.out[2].bias.data
|
||||
|
||||
self.remove_ldm()
|
||||
|
||||
elif self.config.ddpm:
|
||||
# =============== SET WEIGHTS ===============
|
||||
# =============== TIME ======================
|
||||
self.time_embed[0] = self.temb.dense[0]
|
||||
self.time_embed[2] = self.temb.dense[1]
|
||||
|
||||
for i, block in enumerate(self.down):
|
||||
if hasattr(block, "downsample"):
|
||||
self.downsample_blocks[i].downsamplers[0].conv.weight.data = block.downsample.conv.weight.data
|
||||
self.downsample_blocks[i].downsamplers[0].conv.bias.data = block.downsample.conv.bias.data
|
||||
if hasattr(block, "block") and len(block.block) > 0:
|
||||
for j in range(self.num_res_blocks):
|
||||
self.downsample_blocks[i].resnets[j].set_weight(block.block[j])
|
||||
if hasattr(block, "attn") and len(block.attn) > 0:
|
||||
for j in range(self.num_res_blocks):
|
||||
self.downsample_blocks[i].attentions[j].set_weight(block.attn[j])
|
||||
|
||||
self.mid_new_2.resnets[0].set_weight(self.mid.block_1)
|
||||
self.mid_new_2.resnets[1].set_weight(self.mid.block_2)
|
||||
self.mid_new_2.attentions[0].set_weight(self.mid.attn_1)
|
||||
|
||||
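The weight-porting loops above walk the flat LDM `input_blocks` / `output_blocks` lists and recover the target block with `i // (num_res_blocks + 1)` and the layer inside it with the remainder. A tiny illustration for `num_res_blocks = 2`:

```python
num_res_blocks = 2
for i in range(9):
    block_id, layer_in_block_id = divmod(i, num_res_blocks + 1)
    print(f"flat index {i} -> block {block_id}, layer {layer_in_block_id}")
# flat index 0 -> block 0, layer 0
# flat index 1 -> block 0, layer 1
# ...
# flat index 8 -> block 2, layer 2
```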
def init_for_ddpm(
|
||||
self,
|
||||
ch_mult,
|
||||
ch,
|
||||
num_res_blocks,
|
||||
resolution,
|
||||
in_channels,
|
||||
resamp_with_conv,
|
||||
attn_resolutions,
|
||||
out_ch,
|
||||
dropout=0.1,
|
||||
):
|
||||
ch_mult = tuple(ch_mult)
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch * 4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(self.ch, self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
||||
]
|
||||
)
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock2D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
|
||||
self.mid.block_2 = ResnetBlock2D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
||||
self.mid_new.resnets[0] = self.mid.block_1
|
||||
self.mid_new.attentions[0] = self.mid.attn_1
|
||||
self.mid_new.resnets[1] = self.mid.block_2
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
skip_in = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch * in_ch_mult[i_level]
|
||||
block.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=block_in + skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def init_for_ldm(
|
||||
self,
|
||||
in_channels,
|
||||
@@ -234,7 +513,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
out_channels,
|
||||
):
|
||||
# TODO(PVP) - delete after weight conversion
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||
@@ -255,6 +533,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
dims = 2
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
|
||||
@@ -389,42 +673,15 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
|
||||
for i, input_layer in enumerate(self.input_blocks[1:]):
|
||||
block_id = i // (num_res_blocks + 1)
|
||||
layer_in_block_id = i % (num_res_blocks + 1)
|
||||
self.out = nn.Sequential(
|
||||
nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(model_channels, out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
if layer_in_block_id == 2:
|
||||
self.downsample_blocks[block_id].downsamplers[0].op.weight.data = input_layer[0].op.weight.data
|
||||
self.downsample_blocks[block_id].downsamplers[0].op.bias.data = input_layer[0].op.bias.data
|
||||
elif len(input_layer) > 1:
|
||||
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
else:
|
||||
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
|
||||
self.mid.resnets[0].set_weight(self.middle_block[0])
|
||||
self.mid.resnets[1].set_weight(self.middle_block[2])
|
||||
self.mid.attentions[0].set_weight(self.middle_block[1])
|
||||
|
||||
for i, input_layer in enumerate(self.output_blocks):
|
||||
block_id = i // (num_res_blocks + 1)
|
||||
layer_in_block_id = i % (num_res_blocks + 1)
|
||||
|
||||
if len(input_layer) > 2:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
|
||||
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
|
||||
elif len(input_layer) > 1:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
else:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
|
||||
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
|
||||
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
|
||||
def remove_ldm(self):
|
||||
del self.time_embed
|
||||
del self.input_blocks
|
||||
del self.middle_block
|
||||
del self.output_blocks
|
||||
del self.out
|
||||
|
||||
@@ -271,6 +271,27 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
print("Original success!!!")
|
||||
|
||||
model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
|
||||
time_step = torch.tensor([10])
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
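The hard-coded `expected_output_slice` holds nine values because the test compares only the last channel's bottom-right 3x3 patch. A quick illustration of the slicing:

```python
import torch

output = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)
output_slice = output[0, -1, -3:, -3:].flatten()
print(output_slice.shape)  # torch.Size([9])
```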
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = GlideSuperResUNetModel
|
||||
@@ -486,18 +507,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
"out_channels": 4,
|
||||
"num_res_blocks": 2,
|
||||
"attention_resolutions": (16,),
|
||||
"block_input_channels": [32, 32],
|
||||
"block_output_channels": [32, 64],
|
||||
"block_channels": (32, 64),
|
||||
"num_head_channels": 32,
|
||||
"conv_resample": True,
|
||||
"down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
|
||||
"up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
|
||||
"ldm": True,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
|
||||
model, loading_info = UNetUnconditionalModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy", output_loading_info=True, ldm=True
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
@@ -507,7 +530,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy")
|
||||
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||