1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

merge from master

This commit is contained in:
Patrick von Platen
2022-07-01 16:20:05 +00:00
6 changed files with 174 additions and 40 deletions

View File

@@ -0,0 +1,40 @@
import argparse
import torch
from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
from diffusers import DDPMScheduler
def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")
model = DiffWave()
model.load_state_dict(sd, strict=False)
ts, _, betas, _ = noise_scheduler_sd
ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist())
noise_scheduler = DDPMScheduler(
timesteps=12,
trained_betas=betas,
timestep_values=ts,
clip_sample=False,
tensor_format="np",
)
pipeline = BDDMPipeline(model, noise_scheduler)
pipeline.save_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, required=True)
parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
args = parser.parse_args()
convert_bddm_orginal(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path)

View File

@@ -0,0 +1,56 @@
import argparse
import OmegaConf
import torch
from diffusers import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler
def convert_ldm_original(checkpoint_path, config_path, output_path):
config = OmegaConf.load(config_path)
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
keys = list(state_dict.keys())
# extract state_dict for VQVAE
first_stage_dict = {}
first_stage_key = "first_stage_model."
for key in keys:
if key.startswith(first_stage_key):
first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
# extract state_dict for UNetLDM
unet_state_dict = {}
unet_key = "model.diffusion_model."
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
vqvae_init_args = config.model.params.first_stage_config.params
unet_init_args = config.model.params.unet_config.params
vqvae = VQModel(**vqvae_init_args).eval()
vqvae.load_state_dict(first_stage_dict)
unet = UNetLDMModel(**unet_init_args).eval()
unet.load_state_dict(unet_state_dict)
noise_scheduler = DDIMScheduler(
timesteps=config.model.params.timesteps,
beta_schedule="scaled_linear",
beta_start=config.model.params.linear_start,
beta_end=config.model.params.linear_end,
clip_sample=False,
)
pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
pipeline.save_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, required=True)
parser.add_argument("--config_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
args = parser.parse_args()
convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)

View File

@@ -1,4 +1,5 @@
from abc import abstractmethod
from functools import partial
import numpy as np
import torch
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(self, x):
assert x.shape[1] == self.channels
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
if self.name == "conv":
x = self.conv(x)
else:
x = self.Conv2d_0(x)
return x
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
if name == "conv":
self.conv = conv
elif name == "Conv2d_0":
self.Conv2d_0 = conv
else:
self.op = conv
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
if self.name == "conv":
return self.conv(x)
elif self.name == "Conv2d_0":
return self.Conv2d_0(x)
else:
return self.op(x)
@@ -469,6 +484,7 @@ class ResnetBlockBigGANpp(nn.Module):
up=False,
down=False,
dropout=0.1,
fir=False,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
@@ -479,8 +495,20 @@ class ResnetBlockBigGANpp(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir = fir
self.fir_kernel = fir_kernel
if self.up:
if self.fir:
self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
else:
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
elif self.down:
if self.fir:
self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
else:
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
@@ -503,11 +531,11 @@ class ResnetBlockBigGANpp(nn.Module):
h = self.act(self.GroupNorm_0(x))
if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
h = self.upsample(h)
x = self.upsample(x)
elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.downsample(h)
x = self.downsample(x)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding

View File

@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import downsample_2d, upfirdn2d, upsample_2d
from .resnet import downsample_2d, upfirdn2d, upsample_2d, Downsample, Upsample
from .resnet import ResnetBlock
@@ -185,18 +185,19 @@ class Combine(nn.Module):
class FirUpsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_ch = out_ch if out_ch else in_ch
if with_conv:
self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
self.with_conv = with_conv
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_ch = out_ch
self.out_channels = out_channels
def forward(self, x):
if self.with_conv:
if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
h = upsample_2d(x, self.fir_kernel, factor=2)
@@ -204,18 +205,19 @@ class FirUpsample(nn.Module):
class FirDownsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_ch = out_ch if out_ch else in_ch
if with_conv:
self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.with_conv = with_conv
self.out_ch = out_ch
self.use_conv = use_conv
self.out_channels = out_channels
def forward(self, x):
if self.with_conv:
if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = downsample_2d(x, self.fir_kernel, factor=2)
@@ -229,13 +231,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self,
image_size=1024,
num_channels=3,
centered=False,
attn_resolutions=(16,),
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
conditional=True,
conv_size=3,
dropout=0.0,
embedding_type="fourier",
fir=True, # TODO (patil-suraj) remove this option from here and pre-trained model configs
fir=True,
fir_kernel=(1, 3, 3, 1),
fourier_scale=16,
init_scale=0.0,
@@ -253,12 +256,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.register_to_config(
image_size=image_size,
num_channels=num_channels,
centered=centered,
attn_resolutions=attn_resolutions,
ch_mult=ch_mult,
conditional=conditional,
conv_size=conv_size,
dropout=dropout,
embedding_type=embedding_type,
fir=fir,
fir_kernel=fir_kernel,
fourier_scale=fourier_scale,
init_scale=init_scale,
@@ -308,21 +313,26 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules.append(Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
if self.fir:
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample, name="Conv2d_0")
if progressive == "output_skip":
self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
if self.fir:
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
# Downsampling block
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
channels = num_channels
if progressive_input != "none":
@@ -376,7 +386,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch *= 2
elif progressive_input == "residual":
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
input_pyramid_ch = in_ch
hs_c.append(in_ch)
@@ -448,7 +458,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
)
pyramid_ch = channels
elif progressive == "residual":
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name")
@@ -505,7 +515,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
temb = None
# If input data is in [0, 1]
x = 2 * x - 1.0
if not self.config.centered:
x = 2 * x - 1.0
# Downsampling block
input_pyramid = None

View File

@@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
# scale and decode image with vae
image = 1 / 0.18215 * image
# decode image with vae
image = self.vqvae.decode(image)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image

View File

@@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293])
expected_slice = torch.tensor(
[-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self):