diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index bccbaf3a16..3ee67fb365 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -222,9 +222,9 @@ class ResnetBlock(nn.Module): self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - if time_embedding_norm == "default": + if time_embedding_norm == "default" and temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - elif time_embedding_norm == "scale_shift": + elif time_embedding_norm == "scale_shift" and temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps) @@ -427,7 +427,10 @@ class ResnetBlock(nn.Module): h = self.nonlinearity(h) h = h * mask - temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None] + if temb is not None: + temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None] + else: + temb = 0 if self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 6bd9a07099..486cd3903d 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -1,5 +1,3 @@ -import math - import numpy as np import torch import torch.nn as nn @@ -7,26 +5,7 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .attention import AttentionBlock -from .resnet import Downsample, Upsample - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal - embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section - 3.5 of "Attention Is All You Need". 
- """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb +from .resnet import Downsample, ResnetBlock, Upsample def nonlinearity(x): @@ -38,50 +17,6 @@ def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - class Encoder(nn.Module): def __init__( self,