diff --git a/conversion.py b/conversion.py
new file mode 100755
index 0000000000..2606aad18f
--- /dev/null
+++ b/conversion.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import (
+    AutoencoderKL,
+    DDIMPipeline,
+    DDIMScheduler,
+    DDPMPipeline,
+    DDPMScheduler,
+    GlidePipeline,
+    GlideSuperResUNetModel,
+    GlideTextToImageUNetModel,
+    LatentDiffusionPipeline,
+    LatentDiffusionUncondPipeline,
+    NCSNpp,
+    PNDMPipeline,
+    PNDMScheduler,
+    ScoreSdeVePipeline,
+    ScoreSdeVeScheduler,
+    ScoreSdeVpPipeline,
+    ScoreSdeVpScheduler,
+    UNetLDMModel,
+    UNetModel,
+    UNetUnconditionalModel,
+    VQModel,
+)
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.testing_utils import floats_tensor, slow, torch_device
+from diffusers.training_utils import EMAModel
+
+
+def test_output_pretrained_ldm_dummy():
+    model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
+    model.eval()
+
+    torch.manual_seed(0)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(0)
+
+    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+    time_step = torch.tensor([10] * noise.shape[0])
+
+    with torch.no_grad():
+        output = model(noise, time_step)
+
+    print(model)
+    import ipdb; ipdb.set_trace()
+
+
+def test_output_pretrained_ldm():
+    model = UNetUnconditionalModel.from_pretrained("fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True)
+    model.eval()
+
+    torch.manual_seed(0)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(0)
+
+    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+    time_step = torch.tensor([10] * noise.shape[0])
+
+    with torch.no_grad():
+        output = model(noise, time_step)
+
+    print(model)
+    import ipdb; ipdb.set_trace()
+
+# To see how the final model should look
+test_output_pretrained_ldm_dummy()
+test_output_pretrained_ldm()
+# => this is the architecture in which the model should be saved in the new format
+# -> verify new repo with the following tests (in `test_modeling_utils.py`)
+# - test_ldm_uncond (in PipelineTesterMixin)
+# - test_output_pretrained (in UNetLDMModelTests)
diff --git a/docs/source/examples/diffusers_for_vision.mdx b/docs/source/examples/diffusers_for_vision.mdx
index 624938f59d..28a8dcd91d 100644
--- a/docs/source/examples/diffusers_for_vision.mdx
+++ b/docs/source/examples/diffusers_for_vision.mdx
@@ -111,7 +111,7 @@ prompt = "A painting of a squirrel eating a burger"
 image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)
 
 image_processed = image.cpu().permute(0, 2, 3, 1)
-image_processed = image_processed * 255.
+image_processed = image_processed * 255.0 image_processed = image_processed.numpy().astype(np.uint8) image_pil = PIL.Image.fromarray(image_processed[0]) @@ -143,6 +143,7 @@ audio = bddm(mel_spec, generator, torch_device=torch_device) # save generated audio from scipy.io.wavfile import write as wavwrite + sampling_rate = 22050 wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy()) ``` diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 05dd00cc10..abc6e094a0 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -116,6 +116,7 @@ class ConfigMixin: use_auth_token = kwargs.pop("use_auth_token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) user_agent = {"file_type": "config"} @@ -150,6 +151,7 @@ class ConfigMixin: local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, + subfolder=subfolder, ) except RepositoryNotFoundError: diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index aa60ffa936..4352627e0d 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -321,6 +321,7 @@ class ModelMixin(torch.nn.Module): use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) from_auto_class = kwargs.pop("_from_auto", False) + subfolder = kwargs.pop("subfolder", None) user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} @@ -336,6 +337,7 @@ class ModelMixin(torch.nn.Module): local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, + subfolder=subfolder, **kwargs, ) model.register_to_config(name_or_path=pretrained_model_name_or_path) @@ -363,6 +365,7 @@ class ModelMixin(torch.nn.Module): local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, + subfolder=subfolder, ) except RepositoryNotFoundError: diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 4b780c7f72..bba5d73577 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -51,6 +51,7 @@ class AttentionBlock(nn.Module): overwrite_qkv=False, overwrite_linear=False, rescale_output_factor=1.0, + eps=1e-5, ): super().__init__() self.channels = channels @@ -62,7 +63,7 @@ class AttentionBlock(nn.Module): ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels - self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True) + self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) self.qkv = nn.Conv1d(channels, channels * 3, 1) self.n_heads = self.num_heads self.rescale_output_factor = rescale_output_factor @@ -165,7 +166,7 @@ class AttentionBlock(nn.Module): return result -class AttentionBlockNew(nn.Module): +class AttentionBlockNew_2(nn.Module): """ An attention block that allows spatial positions to attend to each other. 
@@ -180,11 +181,14 @@ class AttentionBlockNew(nn.Module): num_groups=32, encoder_channels=None, rescale_output_factor=1.0, + eps=1e-5, ): super().__init__() - self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True) + self.channels = channels + self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) self.qkv = nn.Conv1d(channels, channels * 3, 1) self.n_heads = channels // num_head_channels + self.num_head_size = num_head_channels self.rescale_output_factor = rescale_output_factor if encoder_channels is not None: @@ -192,6 +196,28 @@ class AttentionBlockNew(nn.Module): self.proj = zero_module(nn.Conv1d(channels, channels, 1)) + # ------------------------- new ----------------------- + num_heads = self.n_heads + self.channels = channels + if num_head_channels is None: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = zero_module(nn.Linear(channels, channels, 1)) + # ------------------------- new ----------------------- + def set_weight(self, attn_layer): self.norm.weight.data = attn_layer.norm.weight.data self.norm.bias.data = attn_layer.norm.bias.data @@ -202,6 +228,89 @@ class AttentionBlockNew(nn.Module): self.proj.weight.data = attn_layer.proj.weight.data self.proj.bias.data = attn_layer.proj.bias.data + if hasattr(attn_layer, "q"): + module = attn_layer + qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[ + :, :, :, 0 + ] + qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0) + + self.qkv.weight.data = qkv_weight + self.qkv.bias.data = qkv_bias + + proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1)) + proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0] + proj_out.bias.data = module.proj_out.bias.data + + self.proj = proj_out + + self.set_weights_2(attn_layer) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def set_weights_2(self, attn_layer): + self.group_norm.weight.data = attn_layer.norm.weight.data + self.group_norm.bias.data = attn_layer.norm.bias.data + + qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels) + qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads) + + q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1) + q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1) + + self.query.weight.data = q_w.reshape(-1, self.channels) + self.key.weight.data = k_w.reshape(-1, self.channels) + self.value.weight.data = v_w.reshape(-1, self.channels) + + self.query.bias.data = q_b.reshape(-1) + self.key.bias.data = k_b.reshape(-1) + self.value.bias.data = v_b.reshape(-1) + + 
self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0] + self.proj_attn.bias.data = attn_layer.proj.bias.data + + def forward_2(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + # transpose + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # get scores + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads) + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # compute attention output + context_states = torch.matmul(attention_probs, value_states) + + context_states = context_states.permute(0, 2, 1, 3).contiguous() + new_context_states_shape = context_states.size()[:-2] + (self.channels,) + context_states = context_states.view(new_context_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(context_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + def forward(self, x, encoder_out=None): b, c, *spatial = x.shape hid_states = self.norm(x).view(b, c, -1) @@ -230,10 +339,119 @@ class AttentionBlockNew(nn.Module): h = h.reshape(b, c, *spatial) result = x + h - result = result / self.rescale_output_factor - return result + result_2 = self.forward_2(x) + + print((result - result_2).abs().sum()) + + return result_2 + + +class AttentionBlockNew(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ Uses three q, k, v linear layers to compute attention + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=None, + num_groups=32, + rescale_output_factor=1.0, + eps=1e-5, + ): + super().__init__() + self.channels = channels + if num_head_channels is None: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = zero_module(nn.Linear(channels, channels, 1)) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + # transpose + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # get scores + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads) + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # compute attention output + context_states = torch.matmul(attention_probs, value_states) + + context_states = context_states.permute(0, 2, 1, 3).contiguous() + new_context_states_shape = context_states.size()[:-2] + (self.channels,) + context_states = context_states.view(new_context_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(context_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def set_weight(self, attn_layer): + self.group_norm.weight.data = attn_layer.norm.weight.data + self.group_norm.bias.data = attn_layer.norm.bias.data + + qkv_weight = attn_layer.qkv.weight.data.reshape( + self.num_heads, 3 * self.channels // self.num_heads, self.channels + ) + qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads) + + q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1) + q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1) + + self.query.weight.data = q_w.reshape(-1, self.channels) + self.key.weight.data = k_w.reshape(-1, self.channels) + self.value.weight.data = v_w.reshape(-1, self.channels) + + self.query.bias.data = q_b.reshape(-1) + self.key.bias.data = k_b.reshape(-1) + self.value.bias.data = v_b.reshape(-1) + + 
self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0] + self.proj_attn.bias.data = attn_layer.proj.bias.data class SpatialTransformer(nn.Module): diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 540b1d94f6..d211c4bc43 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -81,8 +81,10 @@ class Downsample2D(nn.Module): self.conv = conv elif name == "Conv2d_0": self.Conv2d_0 = conv + self.conv = conv else: self.op = conv + self.conv = conv def forward(self, x): assert x.shape[1] == self.channels @@ -90,13 +92,16 @@ class Downsample2D(nn.Module): pad = (0, 1, 0, 1) x = F.pad(x, pad, mode="constant", value=0) + return self.conv(x) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if self.name == "conv": - return self.conv(x) - elif self.name == "Conv2d_0": - return self.Conv2d_0(x) - else: - return self.op(x) + + +# if self.name == "conv": +# return self.conv(x) +# elif self.name == "Conv2d_0": +# return self.Conv2d_0(x) +# else: +# return self.op(x) class Upsample1D(nn.Module): @@ -656,9 +661,9 @@ class ResnetBlock(nn.Module): self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if time_embedding_norm == "default" and temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) elif time_embedding_norm == "scale_shift" and temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) + self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) @@ -691,9 +696,9 @@ class ResnetBlock(nn.Module): self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut - self.nin_shortcut = None + self.conv_shortcut = None if self.use_nin_shortcut: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb): h = x @@ -715,7 +720,7 @@ class ResnetBlock(nn.Module): h = self.nonlinearity(h) if temb is not None: - temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None] + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] else: temb = 0 @@ -738,8 +743,8 @@ class ResnetBlock(nn.Module): h = self.norm2(h) h = self.nonlinearity(h) - if self.nin_shortcut is not None: - x = self.nin_shortcut(x) + if self.conv_shortcut is not None: + x = self.conv_shortcut(x) return (x + h) / self.output_scale_factor @@ -750,8 +755,8 @@ class ResnetBlock(nn.Module): self.conv1.weight.data = resnet.conv1.weight.data self.conv1.bias.data = resnet.conv1.bias.data - self.temb_proj.weight.data = resnet.temb_proj.weight.data - self.temb_proj.bias.data = resnet.temb_proj.bias.data + self.time_emb_proj.weight.data = resnet.temb_proj.weight.data + self.time_emb_proj.bias.data = resnet.temb_proj.bias.data self.norm2.weight.data = resnet.norm2.weight.data self.norm2.bias.data = resnet.norm2.bias.data @@ -760,8 +765,8 @@ class ResnetBlock(nn.Module): self.conv2.bias.data = resnet.conv2.bias.data if self.use_nin_shortcut: - self.nin_shortcut.weight.data = resnet.nin_shortcut.weight.data - self.nin_shortcut.bias.data = resnet.nin_shortcut.bias.data + self.conv_shortcut.weight.data = 
resnet.nin_shortcut.weight.data + self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data # TODO(Patrick) - just there to convert the weights; can delete afterward diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 7c49769722..de2824af7a 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -177,7 +177,9 @@ class UNetModel(ModelMixin, ConfigMixin): hs.append(self.down[i_level].downsample(hs[-1])) # middle + print("hs", hs[-1].abs().sum()) h = self.mid_new(hs[-1], temb) + print("h", h.abs().sum()) # upsampling for i_level in reversed(range(self.num_resolutions)): diff --git a/src/diffusers/models/unet_new.py b/src/diffusers/models/unet_new.py index 5669c60cb5..83235d3d16 100644 --- a/src/diffusers/models/unet_new.py +++ b/src/diffusers/models/unet_new.py @@ -29,9 +29,10 @@ def get_down_block( resnet_eps, resnet_act_fn, attn_num_head_channels, + downsample_padding=None, ): if down_block_type == "UNetResDownBlock2D": - return UNetResAttnDownBlock2D( + return UNetResDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -39,6 +40,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, ) elif down_block_type == "UNetResAttnDownBlock2D": return UNetResAttnDownBlock2D( @@ -57,7 +59,8 @@ def get_up_block( up_block_type, num_layers, in_channels, - next_channels, + out_channels, + prev_output_channel, temb_channels, add_upsample, resnet_eps, @@ -68,7 +71,8 @@ def get_up_block( return UNetResUpBlock2D( num_layers=num_layers, in_channels=in_channels, - next_channels=next_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -78,7 +82,8 @@ def get_up_block( return UNetResAttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, - next_channels=next_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, @@ -100,11 +105,14 @@ class UNetMidBlock2D(nn.Module): resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, + attention_type="default", output_scale_factor=1.0, **kwargs, ): super().__init__() + self.attention_type = attention_type + # there is always at least one resnet resnets = [ ResnetBlock( @@ -128,6 +136,7 @@ class UNetMidBlock2D(nn.Module): in_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, + eps=resnet_eps, ) ) resnets.append( @@ -148,18 +157,15 @@ class UNetMidBlock2D(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_states=None, mask=None): - if mask is not None: - hidden_states = self.resnets[0](hidden_states, temb, mask=mask) - else: - hidden_states = self.resnets[0](hidden_states, temb) + def forward(self, hidden_states, temb=None, encoder_states=None): + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_states) - if mask is not None: - hidden_states = resnet(hidden_states, temb, mask=mask) + if self.attention_type == "default": + hidden_states = attn(hidden_states) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_states) + hidden_states = resnet(hidden_states, 
temb) return hidden_states @@ -178,6 +184,7 @@ class UNetResAttnDownBlock2D(nn.Module): resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, + attention_type="default", output_scale_factor=1.0, add_downsample=True, ): @@ -185,6 +192,8 @@ class UNetResAttnDownBlock2D(nn.Module): resnets = [] attentions = [] + self.attention_type = attention_type + for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( @@ -206,6 +215,7 @@ class UNetResAttnDownBlock2D(nn.Module): out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, + eps=resnet_eps, ) ) @@ -251,6 +261,7 @@ class UNetResDownBlock2D(nn.Module): resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, + downsample_padding=1, ): super().__init__() resnets = [] @@ -276,7 +287,11 @@ class UNetResDownBlock2D(nn.Module): if add_downsample: self.downsamplers = nn.ModuleList( - [Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")] + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] ) else: self.downsamplers = None @@ -301,7 +316,8 @@ class UNetResAttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, - next_channels: int, + prev_output_channel: int, + out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, @@ -310,7 +326,7 @@ class UNetResAttnUpBlock2D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_layer_type: str = "self", + attention_type="default", attn_num_head_channels=1, output_scale_factor=1.0, add_upsample=True, @@ -319,12 +335,16 @@ class UNetResAttnUpBlock2D(nn.Module): resnets = [] attentions = [] + self.attention_type = attention_type + for i in range(num_layers): - resnet_channels = in_channels if i < num_layers - 1 else next_channels + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( ResnetBlock( - in_channels=in_channels + resnet_channels, - out_channels=in_channels, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, @@ -337,9 +357,10 @@ class UNetResAttnUpBlock2D(nn.Module): ) attentions.append( AttentionBlockNew( - in_channels, + out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, + eps=resnet_eps, ) ) @@ -347,7 +368,7 @@ class UNetResAttnUpBlock2D(nn.Module): self.resnets = nn.ModuleList(resnets) if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)]) + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None @@ -373,7 +394,8 @@ class UNetResUpBlock2D(nn.Module): def __init__( self, in_channels: int, - next_channels: int, + prev_output_channel: int, + out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, @@ -382,7 +404,6 @@ class UNetResUpBlock2D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_layer_type: str = "self", output_scale_factor=1.0, add_upsample=True, ): @@ -390,11 +411,13 @@ class UNetResUpBlock2D(nn.Module): resnets = [] for i in range(num_layers): - resnet_channels = in_channels if i < num_layers - 1 else 
next_channels + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( ResnetBlock( - in_channels=in_channels + resnet_channels, - out_channels=in_channels, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, @@ -409,7 +432,7 @@ class UNetResUpBlock2D(nn.Module): self.resnets = nn.ModuleList(resnets) if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)]) + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None diff --git a/src/diffusers/models/unet_unconditional.py b/src/diffusers/models/unet_unconditional.py index dba9ab8fcd..c912f29429 100644 --- a/src/diffusers/models/unet_unconditional.py +++ b/src/diffusers/models/unet_unconditional.py @@ -9,6 +9,30 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D from .unet_new import UNetMidBlock2D, get_down_block, get_up_block +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class TimestepEmbedding(nn.Module): + def __init__(self, channel, time_embed_dim): + super().__init__() + + self.linear_1 = nn.Linear(channel, time_embed_dim) + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + class UNetUnconditionalModel(ModelMixin, ConfigMixin): """ The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. 
:param @@ -35,35 +59,66 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): def __init__( self, - image_size, - in_channels, - out_channels, - num_res_blocks, + image_size=None, + in_channels=None, + out_channels=None, + num_res_blocks=None, dropout=0, - block_input_channels=(224, 224, 448, 672), - block_output_channels=(224, 448, 672, 896), + block_channels=(224, 448, 672, 896), down_blocks=( "UNetResDownBlock2D", "UNetResAttnDownBlock2D", "UNetResAttnDownBlock2D", "UNetResAttnDownBlock2D", ), + downsample_padding=1, up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"), resnet_act_fn="silu", resnet_eps=1e-5, conv_resample=True, num_head_channels=32, + flip_sin_to_cos=True, + downscale_freq_shift=0, # To delete once weights are converted + # LDM attention_resolutions=(8, 4, 2), + ldm=False, + # DDPM + out_ch=None, + resolution=None, + attn_resolutions=None, + resamp_with_conv=None, + ch_mult=None, + ch=None, + ddpm=False, ): super().__init__() + # DELETE if statements if not necessary anymore + # DDPM + if ddpm: + out_channels = out_ch + image_size = resolution + block_channels = [x * ch for x in ch_mult] + conv_resample = resamp_with_conv + flip_sin_to_cos = False + downscale_freq_shift = 1 + resnet_eps = 1e-6 + block_channels = (32, 64) + down_blocks = ( + "UNetResDownBlock2D", + "UNetResAttnDownBlock2D", + ) + up_blocks = ("UNetResUpBlock2D", "UNetResAttnUpBlock2D") + downsample_padding = 0 + num_head_channels = 64 + # register all __init__ params with self.register self.register_to_config( image_size=image_size, in_channels=in_channels, - block_input_channels=block_input_channels, - block_output_channels=block_output_channels, + block_channels=block_channels, + downsample_padding=downsample_padding, out_channels=out_channels, num_res_blocks=num_res_blocks, down_blocks=down_blocks, @@ -71,37 +126,34 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): dropout=dropout, conv_resample=conv_resample, num_head_channels=num_head_channels, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, # (TODO(PVP) - To delete once weights are converted attention_resolutions=attention_resolutions, + ldm=ldm, + ddpm=ddpm, ) # To delete - replace with config values self.image_size = image_size - self.in_channels = in_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.dropout = dropout + time_embed_dim = block_channels[0] * 4 - time_embed_dim = block_input_channels[0] * 4 + # # input + self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1)) - # ======================== Input =================== - self.conv_in = nn.Conv2d(in_channels, block_input_channels[0], kernel_size=3, padding=(1, 1)) - - # ======================== Time ==================== - self.time_embed = nn.Sequential( - nn.Linear(block_input_channels[0], time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), - ) - - # ======================== Down ==================== - input_channels = list(block_input_channels) - output_channels = list(block_output_channels) + # # time + self.time_embedding = TimestepEmbedding(block_channels[0], time_embed_dim) self.downsample_blocks = nn.ModuleList([]) - for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)): - down_block_type = down_blocks[i] - is_final_block = i == len(input_channels) - 1 + self.mid = None + self.upsample_blocks = nn.ModuleList([]) + + # down + output_channel = block_channels[0] + 
for i, down_block_type in enumerate(down_blocks): + input_channel = output_channel + output_channel = block_channels[i] + is_final_block = i == len(block_channels) - 1 down_block = get_down_block( down_block_type, @@ -113,30 +165,48 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, attn_num_head_channels=num_head_channels, + downsample_padding=downsample_padding, ) self.downsample_blocks.append(down_block) - # ======================== Mid ==================== - self.mid = UNetMidBlock2D( - in_channels=output_channels[-1], - dropout=dropout, - temb_channels=time_embed_dim, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_time_scale_shift="default", - attn_num_head_channels=num_head_channels, - ) + # mid + if self.config.ddpm: + self.mid_new_2 = UNetMidBlock2D( + in_channels=block_channels[-1], + dropout=dropout, + temb_channels=time_embed_dim, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift="default", + attn_num_head_channels=num_head_channels, + ) + else: + self.mid = UNetMidBlock2D( + in_channels=block_channels[-1], + dropout=dropout, + temb_channels=time_embed_dim, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift="default", + attn_num_head_channels=num_head_channels, + ) - self.upsample_blocks = nn.ModuleList([]) - for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))): - up_block_type = up_blocks[i] - is_final_block = i == len(input_channels) - 1 + # up + reversed_block_channels = list(reversed(block_channels)) + output_channel = reversed_block_channels[0] + for i, up_block_type in enumerate(up_blocks): + prev_output_channel = output_channel + output_channel = reversed_block_channels[i] + input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)] + + is_final_block = i == len(block_channels) - 1 up_block = get_up_block( up_block_type, num_layers=num_res_blocks + 1, - in_channels=output_channel, - next_channels=input_channel, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, add_upsample=not is_final_block, resnet_eps=resnet_eps, @@ -144,50 +214,72 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): attn_num_head_channels=num_head_channels, ) self.upsample_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=32, eps=1e-5) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1) # ======================== Out ==================== - self.out = nn.Sequential( - nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5), - nn.SiLU(), - nn.Conv2d(block_input_channels[0], out_channels, 3, padding=1), - ) - # =========== TO DELETE AFTER CONVERSION ========== - transformer_depth = 1 - context_dim = None - legacy = True - num_heads = -1 - model_channels = block_input_channels[0] - channel_mult = tuple([x // model_channels for x in block_output_channels]) - self.init_for_ldm( - in_channels, - model_channels, - channel_mult, - num_res_blocks, - dropout, - time_embed_dim, - attention_resolutions, - num_head_channels, - num_heads, - legacy, - False, - transformer_depth, - context_dim, - conv_resample, - out_channels, - ) + self.is_overwritten = False + if ldm: + # =========== TO DELETE AFTER CONVERSION ========== + transformer_depth = 1 + 
context_dim = None + legacy = True + num_heads = -1 + model_channels = block_channels[0] + channel_mult = tuple([x // model_channels for x in block_channels]) + self.init_for_ldm( + in_channels, + model_channels, + channel_mult, + num_res_blocks, + dropout, + time_embed_dim, + attention_resolutions, + num_head_channels, + num_heads, + legacy, + False, + transformer_depth, + context_dim, + conv_resample, + out_channels, + ) + if ddpm: + self.init_for_ddpm( + ch_mult, + ch, + num_res_blocks, + resolution, + in_channels, + resamp_with_conv, + attn_resolutions, + out_ch, + dropout=0.1, + ) def forward(self, sample, timesteps=None): + # TODO(PVP) - to delete later + if not self.is_overwritten: + self.set_weights() + # 1. time step embeddings if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + t_emb = get_timestep_embedding( - timesteps, self.config.block_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0 + timesteps, + self.config.block_channels[0], + flip_sin_to_cos=self.config.flip_sin_to_cos, + downscale_freq_shift=self.config.downscale_freq_shift, ) - emb = self.time_embed(t_emb) + emb = self.time_embedding(t_emb) # 2. pre-process sample - # sample = sample.type(self.dtype_) sample = self.conv_in(sample) # 3. down blocks @@ -198,8 +290,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): # append to tuple down_block_res_samples += res_samples + print("sample", sample.abs().sum()) # 4. mid block - sample = self.mid(sample, emb) + if self.config.ddpm: + sample = self.mid_new_2(sample, emb) + else: + sample = self.mid(sample, emb) + print("sample", sample.abs().sum()) # 5. up blocks for upsample_block in self.upsample_blocks: @@ -211,10 +308,192 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): sample = upsample_block(sample, res_samples, emb) # 6. 
post-process sample - sample = self.out(sample) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) return sample + def set_weights(self): + self.is_overwritten = True + if self.config.ldm: + + self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data + self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data + self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data + self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data + + # ================ SET WEIGHTS OF ALL WEIGHTS ================== + for i, input_layer in enumerate(self.input_blocks[1:]): + block_id = i // (self.config.num_res_blocks + 1) + layer_in_block_id = i % (self.config.num_res_blocks + 1) + + if layer_in_block_id == 2: + self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data + self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data + elif len(input_layer) > 1: + self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) + self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) + else: + self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) + + self.mid.resnets[0].set_weight(self.middle_block[0]) + self.mid.resnets[1].set_weight(self.middle_block[2]) + self.mid.attentions[0].set_weight(self.middle_block[1]) + + for i, input_layer in enumerate(self.output_blocks): + block_id = i // (self.config.num_res_blocks + 1) + layer_in_block_id = i % (self.config.num_res_blocks + 1) + + if len(input_layer) > 2: + self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) + self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) + self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data + self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data + elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__: + self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) + self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data + self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data + elif len(input_layer) > 1: + self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) + self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) + else: + self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) + + self.conv_in.weight.data = self.input_blocks[0][0].weight.data + self.conv_in.bias.data = self.input_blocks[0][0].bias.data + + self.conv_norm_out.weight.data = self.out[0].weight.data + self.conv_norm_out.bias.data = self.out[0].bias.data + self.conv_out.weight.data = self.out[2].weight.data + self.conv_out.bias.data = self.out[2].bias.data + + self.remove_ldm() + + elif self.config.ddpm: + # =============== SET WEIGHTS =============== + # =============== TIME ====================== + self.time_embed[0] = self.temb.dense[0] + self.time_embed[2] = self.temb.dense[1] + + for i, block in enumerate(self.down): + if hasattr(block, "downsample"): + self.downsample_blocks[i].downsamplers[0].conv.weight.data = block.downsample.conv.weight.data + self.downsample_blocks[i].downsamplers[0].conv.bias.data = block.downsample.conv.bias.data + 
if hasattr(block, "block") and len(block.block) > 0: + for j in range(self.num_res_blocks): + self.downsample_blocks[i].resnets[j].set_weight(block.block[j]) + if hasattr(block, "attn") and len(block.attn) > 0: + for j in range(self.num_res_blocks): + self.downsample_blocks[i].attentions[j].set_weight(block.attn[j]) + + self.mid_new_2.resnets[0].set_weight(self.mid.block_1) + self.mid_new_2.resnets[1].set_weight(self.mid.block_2) + self.mid_new_2.attentions[0].set_weight(self.mid.attn_1) + + def init_for_ddpm( + self, + ch_mult, + ch, + num_res_blocks, + resolution, + in_channels, + resamp_with_conv, + attn_resolutions, + out_ch, + dropout=0.1, + ): + ch_mult = tuple(ch_mult) + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock2D( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttentionBlock(block_in, overwrite_qkv=True)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock2D( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) + self.mid.block_2 = ResnetBlock2D( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + self.mid_new.resnets[0] = self.mid.block_1 + self.mid_new.attentions[0] = self.mid.attn_1 + self.mid_new.resnets[1] = self.mid.block_2 + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock2D( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttentionBlock(block_in, overwrite_qkv=True)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + def init_for_ldm( self, in_channels, @@ -234,7 +513,6 @@ 
class UNetUnconditionalModel(ModelMixin, ConfigMixin): out_channels, ): # TODO(PVP) - delete after weight conversion - class TimestepEmbedSequential(nn.Sequential): """ A sequential module that passes timestep embeddings to the children that support it as an extra input. @@ -255,6 +533,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): return nn.Conv3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + dims = 2 self.input_blocks = nn.ModuleList( [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] @@ -389,42 +673,15 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch - # ================ SET WEIGHTS OF ALL WEIGHTS ================== - for i, input_layer in enumerate(self.input_blocks[1:]): - block_id = i // (num_res_blocks + 1) - layer_in_block_id = i % (num_res_blocks + 1) + self.out = nn.Sequential( + nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5), + nn.SiLU(), + nn.Conv2d(model_channels, out_channels, 3, padding=1), + ) - if layer_in_block_id == 2: - self.downsample_blocks[block_id].downsamplers[0].op.weight.data = input_layer[0].op.weight.data - self.downsample_blocks[block_id].downsamplers[0].op.bias.data = input_layer[0].op.bias.data - elif len(input_layer) > 1: - self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) - else: - self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - - self.mid.resnets[0].set_weight(self.middle_block[0]) - self.mid.resnets[1].set_weight(self.middle_block[2]) - self.mid.attentions[0].set_weight(self.middle_block[1]) - - for i, input_layer in enumerate(self.output_blocks): - block_id = i // (num_res_blocks + 1) - layer_in_block_id = i % (num_res_blocks + 1) - - if len(input_layer) > 2: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) - self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data - self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data - elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data - self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data - elif len(input_layer) > 1: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) - else: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - - self.conv_in.weight.data = self.input_blocks[0][0].weight.data - self.conv_in.bias.data = self.input_blocks[0][0].bias.data + def remove_ldm(self): + del self.time_embed + del self.input_blocks + del self.middle_block + del self.output_blocks + del self.out diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 1a1e522923..f3c486df64 100755 --- 
a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -271,6 +271,27 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + print("Original success!!!") + + model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True) + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) + time_step = torch.tensor([10]) + + with torch.no_grad(): + output = model(noise, time_step) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase): model_class = GlideSuperResUNetModel @@ -486,18 +507,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): "out_channels": 4, "num_res_blocks": 2, "attention_resolutions": (16,), - "block_input_channels": [32, 32], - "block_output_channels": [32, 64], + "block_channels": (32, 64), "num_head_channels": 32, "conv_resample": True, "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"), "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"), + "ldm": True, } inputs_dict = self.dummy_input return init_dict, inputs_dict def test_from_pretrained_hub(self): - model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True) + model, loading_info = UNetUnconditionalModel.from_pretrained( + "fusing/unet-ldm-dummy", output_loading_info=True, ldm=True + ) self.assertIsNotNone(model) self.assertEqual(len(loading_info["missing_keys"]), 0) @@ -507,7 +530,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): assert image is not None, "Make sure output is not None" def test_output_pretrained(self): - model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy") + model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True) model.eval() torch.manual_seed(0)