mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
merge from master
This commit is contained in:
40
scripts/conversion_bddm.py
Normal file
40
scripts/conversion_bddm.py
Normal file
@@ -0,0 +1,40 @@
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
|
||||
def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
    """Convert an original BDDM checkpoint into a saved ``BDDMPipeline``.

    Args:
        checkpoint_path: Path to the model checkpoint. Its ``"model_state_dict"``
            entry is loaded into a freshly constructed ``DiffWave``.
        noise_scheduler_checkpoint_path: Path to the noise-schedule checkpoint.
            The loaded object must unpack into exactly four items; the first is
            used as the timestep values and the third as the betas.
        output_path: Destination passed to ``BDDMPipeline.save_pretrained``.
    """
    # Both checkpoints are mapped onto the CPU while loading.
    model_weights = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
    schedule = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")

    diffwave = DiffWave()
    # strict=False: mismatching keys between checkpoint and model are tolerated.
    diffwave.load_state_dict(model_weights, strict=False)

    # Only the 1st and 3rd entries of the schedule are consumed.
    timestep_tensor, _, beta_tensor, _ = schedule
    timestep_values = list(timestep_tensor.numpy().tolist())
    trained_betas = list(beta_tensor.numpy().tolist())

    scheduler = DDPMScheduler(
        timesteps=12,
        trained_betas=trained_betas,
        timestep_values=timestep_values,
        clip_sample=False,
        tensor_format="np",
    )

    BDDMPipeline(diffwave, scheduler).save_pretrained(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory strings.
    cli = argparse.ArgumentParser()
    for option in ("--checkpoint_path", "--noise_scheduler_checkpoint_path", "--output_path"):
        cli.add_argument(option, type=str, required=True)
    options = cli.parse_args()

    convert_bddm_orginal(
        options.checkpoint_path,
        options.noise_scheduler_checkpoint_path,
        options.output_path,
    )
|
||||
|
||||
|
||||
56
scripts/conversion_ldm_uncond.py
Normal file
56
scripts/conversion_ldm_uncond.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import argparse
|
||||
|
||||
import OmegaConf
|
||||
import torch
|
||||
|
||||
from diffusers import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler
|
||||
|
||||
def _extract_prefixed(state_dict, prefix):
    """Return the entries of ``state_dict`` whose key starts with ``prefix``, with that prefix removed."""
    # Slice the prefix off instead of using str.replace(): replace() would also
    # delete any later occurrence of the prefix text inside the key.
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def convert_ldm_original(checkpoint_path, config_path, output_path):
    """Convert an original unconditional latent-diffusion checkpoint into a saved pipeline.

    The checkpoint's ``"model"`` state dict is split by key prefix into the
    VQ-VAE weights (``first_stage_model.``) and the UNet weights
    (``model.diffusion_model.``). Both models are built from the constructor
    arguments found in the config, put in eval mode and loaded with their
    weights, then bundled with a ``DDIMScheduler`` and saved.

    Args:
        checkpoint_path: Path to the original checkpoint file.
        config_path: Path to the config file read with ``OmegaConf.load``.
        output_path: Destination passed to ``save_pretrained``.
    """
    # NOTE(review): the file's top-level `import OmegaConf` looks wrong; the
    # usual form is `from omegaconf import OmegaConf`. Confirm and fix the
    # import, otherwise this module cannot be imported at all.
    config = OmegaConf.load(config_path)

    # SECURITY: torch.load unpickles the file; only convert checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

    # Split the combined state dict into the two sub-models' weights.
    first_stage_dict = _extract_prefixed(state_dict, "first_stage_model.")
    unet_state_dict = _extract_prefixed(state_dict, "model.diffusion_model.")

    model_params = config.model.params
    vqvae_init_args = model_params.first_stage_config.params
    unet_init_args = model_params.unet_config.params

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    # Scheduler hyper-parameters are taken from the same config.
    noise_scheduler = DDIMScheduler(
        timesteps=model_params.timesteps,
        beta_schedule="scaled_linear",
        beta_start=model_params.linear_start,
        beta_end=model_params.linear_end,
        clip_sample=False,
    )

    pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory strings.
    cli = argparse.ArgumentParser()
    for option in ("--checkpoint_path", "--config_path", "--output_path"):
        cli.add_argument(option, type=str, required=True)
    options = cli.parse_args()

    convert_ldm_original(
        options.checkpoint_path,
        options.config_path,
        options.output_path,
    )
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
|
||||
conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
if self.name == "conv":
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = self.Conv2d_0(x)
|
||||
|
||||
return x
|
||||
|
||||
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
|
||||
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.Conv2d_0 = conv
|
||||
else:
|
||||
self.op = conv
|
||||
|
||||
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
|
||||
|
||||
if self.name == "conv":
|
||||
return self.conv(x)
|
||||
elif self.name == "Conv2d_0":
|
||||
return self.Conv2d_0(x)
|
||||
else:
|
||||
return self.op(x)
|
||||
|
||||
@@ -469,6 +484,7 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
up=False,
|
||||
down=False,
|
||||
dropout=0.1,
|
||||
fir=False,
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
skip_rescale=True,
|
||||
init_scale=0.0,
|
||||
@@ -479,8 +495,20 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.fir = fir
|
||||
self.fir_kernel = fir_kernel
|
||||
|
||||
if self.up:
|
||||
if self.fir:
|
||||
self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
|
||||
else:
|
||||
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
||||
elif self.down:
|
||||
if self.fir:
|
||||
self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
|
||||
else:
|
||||
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
||||
|
||||
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
@@ -503,11 +531,11 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
h = self.act(self.GroupNorm_0(x))
|
||||
|
||||
if self.up:
|
||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
h = self.upsample(h)
|
||||
x = self.upsample(x)
|
||||
elif self.down:
|
||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
h = self.downsample(h)
|
||||
x = self.downsample(x)
|
||||
|
||||
h = self.Conv_0(h)
|
||||
# Add bias to each feature map conditioned on the time embedding
|
||||
|
||||
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import downsample_2d, upfirdn2d, upsample_2d
|
||||
from .resnet import downsample_2d, upfirdn2d, upsample_2d, Downsample, Upsample
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
@@ -185,18 +185,19 @@ class Combine(nn.Module):
|
||||
|
||||
|
||||
class FirUpsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
self.with_conv = with_conv
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.use_conv = use_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_ch = out_ch
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
if self.use_conv:
|
||||
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
|
||||
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
|
||||
@@ -204,18 +205,19 @@ class FirUpsample(nn.Module):
|
||||
|
||||
|
||||
class FirDownsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if with_conv:
|
||||
self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.with_conv = with_conv
|
||||
self.out_ch = out_ch
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
if self.use_conv:
|
||||
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
|
||||
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
|
||||
@@ -229,13 +231,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
image_size=1024,
|
||||
num_channels=3,
|
||||
centered=False,
|
||||
attn_resolutions=(16,),
|
||||
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
|
||||
conditional=True,
|
||||
conv_size=3,
|
||||
dropout=0.0,
|
||||
embedding_type="fourier",
|
||||
fir=True, # TODO (patil-suraj) remove this option from here and pre-trained model configs
|
||||
fir=True,
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
fourier_scale=16,
|
||||
init_scale=0.0,
|
||||
@@ -253,12 +256,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
self.register_to_config(
|
||||
image_size=image_size,
|
||||
num_channels=num_channels,
|
||||
centered=centered,
|
||||
attn_resolutions=attn_resolutions,
|
||||
ch_mult=ch_mult,
|
||||
conditional=conditional,
|
||||
conv_size=conv_size,
|
||||
dropout=dropout,
|
||||
embedding_type=embedding_type,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
fourier_scale=fourier_scale,
|
||||
init_scale=init_scale,
|
||||
@@ -308,21 +313,26 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
modules.append(Linear(nf * 4, nf * 4))
|
||||
|
||||
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
|
||||
Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
|
||||
|
||||
if self.fir:
|
||||
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
|
||||
else:
|
||||
Up_sample = functools.partial(Upsample, name="Conv2d_0")
|
||||
|
||||
if progressive == "output_skip":
|
||||
self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
|
||||
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
|
||||
elif progressive == "residual":
|
||||
pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
|
||||
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
|
||||
|
||||
Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
|
||||
if self.fir:
|
||||
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
|
||||
else:
|
||||
Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
|
||||
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
|
||||
elif progressive_input == "residual":
|
||||
pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
|
||||
|
||||
# Downsampling block
|
||||
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
|
||||
|
||||
channels = num_channels
|
||||
if progressive_input != "none":
|
||||
@@ -376,7 +386,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
in_ch *= 2
|
||||
|
||||
elif progressive_input == "residual":
|
||||
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
|
||||
modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
|
||||
input_pyramid_ch = in_ch
|
||||
|
||||
hs_c.append(in_ch)
|
||||
@@ -448,7 +458,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
)
|
||||
pyramid_ch = channels
|
||||
elif progressive == "residual":
|
||||
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
|
||||
modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
|
||||
pyramid_ch = in_ch
|
||||
else:
|
||||
raise ValueError(f"{progressive} is not a valid name")
|
||||
@@ -505,7 +515,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
temb = None
|
||||
|
||||
# If input data is in [0, 1]
|
||||
x = 2 * x - 1.0
|
||||
if not self.config.centered:
|
||||
x = 2 * x - 1.0
|
||||
|
||||
# Downsampling block
|
||||
input_pyramid = None
|
||||
|
||||
@@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
|
||||
# scale and decode image with vae
|
||||
image = 1 / 0.18215 * image
|
||||
# decode image with vae
|
||||
image = self.vqvae.decode(image)
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
return image
|
||||
|
||||
@@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
|
||||
assert image.shape == (1, 3, 256, 256)
|
||||
expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293])
|
||||
expected_slice = torch.tensor(
|
||||
[-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
|
||||
)
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
def test_module_from_pipeline(self):
|
||||
|
||||
Reference in New Issue
Block a user