1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[ResNet] Refactor resnet from VAE (#65)

This commit is contained in:
Patrick von Platen
2022-07-03 18:43:43 +02:00
committed by GitHub
parent a7b0047e0f
commit 44705a648b
2 changed files with 7 additions and 69 deletions

View File

@@ -222,9 +222,9 @@ class ResnetBlock(nn.Module):
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default":
if time_embedding_norm == "default" and temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift":
elif time_embedding_norm == "scale_shift" and temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps)
@@ -427,7 +427,10 @@ class ResnetBlock(nn.Module):
h = self.nonlinearity(h)
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if temb is not None:
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
else:
temb = 0
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)

View File

@@ -1,5 +1,3 @@
import math
import numpy as np
import torch
import torch.nn as nn
@@ -7,26 +5,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .resnet import Downsample, Upsample
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
from .resnet import Downsample, ResnetBlock, Upsample
def nonlinearity(x):
@@ -38,50 +17,6 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class Encoder(nn.Module):
def __init__(
self,