From 32b93da875e95b8033fe3c493eaaa7bbc9a14048 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 17:10:08 +0200 Subject: [PATCH 01/10] begin conversion script --- .../convert_ldm_to_diffusers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py b/src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py new file mode 100644 index 0000000000..3c512fba9a --- /dev/null +++ b/src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py @@ -0,0 +1,13 @@ +import argparse + +import torch + +from diffusers import UNetLDMModel, VQModel + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--config_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + args = parser.parse_args() + From 5018abff6ef8305c43d6520244f7e8ffb4a28bc3 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 12:01:59 +0200 Subject: [PATCH 02/10] add fir=False back --- src/diffusers/models/resnet.py | 43 +++++++++++--- .../models/unet_sde_score_estimation.py | 59 +++++++++++-------- 2 files changed, 70 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f48a94039e..bad14f7e2a 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from functools import partial import numpy as np import torch @@ -78,18 +79,24 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. 
""" - def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims self.use_conv_transpose = use_conv_transpose + name = self.name if use_conv_transpose: - self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1) + conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1) elif use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv def forward(self, x): assert x.shape[1] == self.channels @@ -102,7 +109,10 @@ class Upsample(nn.Module): x = F.interpolate(x, scale_factor=2.0, mode="nearest") if self.use_conv: - x = self.conv(x) + if self.name == "conv": + x = self.conv(x) + else: + x = self.Conv2d_0(x) return x @@ -134,6 +144,8 @@ class Downsample(nn.Module): if name == "conv": self.conv = conv + elif name == "Conv2d_0": + self.Conv2d_0 = conv else: self.op = conv @@ -145,6 +157,8 @@ class Downsample(nn.Module): if self.name == "conv": return self.conv(x) + elif self.name == "Conv2d_0": + return self.Conv2d_0(x) else: return self.op(x) @@ -390,6 +404,7 @@ class ResnetBlockBigGANpp(nn.Module): up=False, down=False, dropout=0.1, + fir=False, fir_kernel=(1, 3, 3, 1), skip_rescale=True, init_scale=0.0, @@ -400,8 +415,20 @@ class ResnetBlockBigGANpp(nn.Module): self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) self.up = up self.down = down + self.fir = fir self.fir_kernel = fir_kernel + if self.up: + if self.fir: + self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2) + else: + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + elif self.down: + if self.fir: + self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2) + else: + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1) if temb_dim is not None: self.Dense_0 = nn.Linear(temb_dim, out_ch) @@ -424,11 +451,11 @@ class ResnetBlockBigGANpp(nn.Module): h = self.act(self.GroupNorm_0(x)) if self.up: - h = upsample_2d(h, self.fir_kernel, factor=2) - x = upsample_2d(x, self.fir_kernel, factor=2) + h = self.upsample(h) + x = self.upsample(x) elif self.down: - h = downsample_2d(h, self.fir_kernel, factor=2) - x = downsample_2d(x, self.fir_kernel, factor=2) + h = self.downsample(h) + x = self.downsample(x) h = self.Conv_0(h) # Add bias to each feature map conditioned on the time embedding diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 9c82e53e70..d9a4732f0b 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -17,6 +17,7 @@ import functools import math +from unicodedata import name import numpy as np import torch @@ -27,7 +28,7 @@ from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import GaussianFourierProjection, get_timestep_embedding -from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d +from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, 
downsample_2d, upfirdn2d, upsample_2d def _setup_kernel(k): @@ -184,17 +185,17 @@ class Combine(nn.Module): class FirUpsample(nn.Module): - def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): super().__init__() - out_ch = out_ch if out_ch else in_ch - if with_conv: - self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) - self.with_conv = with_conv + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.use_conv = use_conv self.fir_kernel = fir_kernel - self.out_ch = out_ch + self.out_channels = out_channels def forward(self, x): - if self.with_conv: + if self.use_conv: h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) else: h = upsample_2d(x, self.fir_kernel, factor=2) @@ -203,17 +204,17 @@ class FirUpsample(nn.Module): class FirDownsample(nn.Module): - def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): super().__init__() - out_ch = out_ch if out_ch else in_ch - if with_conv: - self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) self.fir_kernel = fir_kernel - self.with_conv = with_conv - self.out_ch = out_ch + self.use_conv = use_conv + self.out_channels = out_channels def forward(self, x): - if self.with_conv: + if self.use_conv: x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) else: x = downsample_2d(x, self.fir_kernel, factor=2) @@ -234,7 +235,7 @@ class NCSNpp(ModelMixin, ConfigMixin): conv_size=3, dropout=0.0, embedding_type="fourier", - fir=True, # TODO (patil-suraj) remove this option from here and pre-trained model configs + fir=True, fir_kernel=(1, 3, 3, 1), fourier_scale=16, init_scale=0.0, @@ -258,6 +259,7 @@ class NCSNpp(ModelMixin, ConfigMixin): conv_size=conv_size, dropout=dropout, embedding_type=embedding_type, + fir=fir, fir_kernel=fir_kernel, fourier_scale=fourier_scale, init_scale=init_scale, @@ -307,24 +309,33 @@ class NCSNpp(ModelMixin, ConfigMixin): modules.append(Linear(nf * 4, nf * 4)) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) - Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel) + + if self.fir: + Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel) + else: + Up_sample = functools.partial(Upsample, name="Conv2d_0") if progressive == "output_skip": - self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False) + self.pyramid_upsample = Up_sample(channels=None, use_conv=False) elif progressive == "residual": - pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True) + pyramid_upsample = functools.partial(Up_sample, use_conv=True) - Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel) + if self.fir: + Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel) + else: + print("fir false") + Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0") if progressive_input == "input_skip": - 
self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False) + self.pyramid_downsample = Down_sample(channels=None, use_conv=False) elif progressive_input == "residual": - pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True) + pyramid_downsample = functools.partial(Down_sample, use_conv=True) ResnetBlock = functools.partial( ResnetBlockBigGANpp, act=act, dropout=dropout, + fir=fir, fir_kernel=fir_kernel, init_scale=init_scale, skip_rescale=skip_rescale, @@ -361,7 +372,7 @@ class NCSNpp(ModelMixin, ConfigMixin): in_ch *= 2 elif progressive_input == "residual": - modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) + modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch)) input_pyramid_ch = in_ch hs_c.append(in_ch) @@ -402,7 +413,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ) pyramid_ch = channels elif progressive == "residual": - modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) + modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch)) pyramid_ch = in_ch else: raise ValueError(f"{progressive} is not a valid name") From 0dbc4779c8bf396d48170dda52befc83288e109f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 12:50:34 +0200 Subject: [PATCH 03/10] add centered back --- src/diffusers/models/unet_sde_score_estimation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index d9a4732f0b..1c2a2d10ff 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -229,6 +229,7 @@ class NCSNpp(ModelMixin, ConfigMixin): self, image_size=1024, num_channels=3, + centered=False, attn_resolutions=(16,), ch_mult=(1, 2, 4, 8, 16, 32, 32, 32), conditional=True, @@ -253,6 +254,7 @@ class NCSNpp(ModelMixin, ConfigMixin): self.register_to_config( image_size=image_size, num_channels=num_channels, + centered=centered, attn_resolutions=attn_resolutions, ch_mult=ch_mult, conditional=conditional, @@ -457,7 +459,8 @@ class NCSNpp(ModelMixin, ConfigMixin): temb = None # If input data is in [0, 1] - x = 2 * x - 1.0 + if not self.config.centered: + x = 2 * x - 1.0 # Downsampling block input_pyramid = None From db5a05742e06d99665797036f34a0e71d0b6ec87 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 12:54:47 +0200 Subject: [PATCH 04/10] fix typo --- src/diffusers/models/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index bad14f7e2a..ba0f9e819b 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -86,7 +86,7 @@ class Upsample(nn.Module): self.use_conv = use_conv self.dims = dims self.use_conv_transpose = use_conv_transpose - name = self.name + self.name = self.name if use_conv_transpose: conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1) From 60a981343ef5b805c5860920bd306d303cdef7b7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 12:55:30 +0200 Subject: [PATCH 05/10] actually fix the typo --- src/diffusers/models/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ba0f9e819b..4983016cf1 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -86,7 +86,7 @@ class Upsample(nn.Module): self.use_conv = 
use_conv self.dims = dims self.use_conv_transpose = use_conv_transpose - self.name = self.name + self.name = name if use_conv_transpose: conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1) From 516cb9e7f88564fb150d454371a0750904e302f7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 12:58:50 +0200 Subject: [PATCH 06/10] fix Upsample --- src/diffusers/models/resnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 4983016cf1..d80ecd88b0 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -88,6 +88,7 @@ class Upsample(nn.Module): self.use_conv_transpose = use_conv_transpose self.name = name + conv = None if use_conv_transpose: conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1) elif use_conv: From 4c293e0e1b77bc0665463b39056d2302e27768e8 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 13:54:33 +0200 Subject: [PATCH 07/10] fix bias when using fir up/down sample --- src/diffusers/models/unet_sde_score_estimation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 1c2a2d10ff..7e368b8763 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -17,7 +17,6 @@ import functools import math -from unicodedata import name import numpy as np import torch @@ -197,6 +196,7 @@ class FirUpsample(nn.Module): def forward(self, x): if self.use_conv: h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) + h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1) else: h = upsample_2d(x, self.fir_kernel, factor=2) @@ -216,6 +216,7 @@ class FirDownsample(nn.Module): def forward(self, x): if self.use_conv: x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) + x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) else: x = downsample_2d(x, self.fir_kernel, factor=2) @@ -313,7 +314,7 @@ class NCSNpp(ModelMixin, ConfigMixin): AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) if self.fir: - Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel) + Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv) else: Up_sample = functools.partial(Upsample, name="Conv2d_0") @@ -323,9 +324,8 @@ class NCSNpp(ModelMixin, ConfigMixin): pyramid_upsample = functools.partial(Up_sample, use_conv=True) if self.fir: - Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel) + Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv) else: - print("fir false") Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0") if progressive_input == "input_skip": From 099d3eab4943dc50f36d9b75172f34bfa22df40c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 16:53:41 +0200 Subject: [PATCH 08/10] add conversion script for LatentDiffusionUncondPipeline --- scripts/conversion_ldm_uncond.py | 56 +++++++++++++++++++ .../convert_ldm_to_diffusers.py | 13 ----- 2 files changed, 56 insertions(+), 13 deletions(-) create mode 100644 scripts/conversion_ldm_uncond.py delete mode 100644 src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py diff --git a/scripts/conversion_ldm_uncond.py b/scripts/conversion_ldm_uncond.py new file mode 100644 index 0000000000..dd3fc7a9e0 --- /dev/null 
+++ b/scripts/conversion_ldm_uncond.py @@ -0,0 +1,56 @@ +import argparse + +import OmegaConf +import torch + +from diffusers import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler + +def convert_ldm_original(checkpoint_path, config_path, output_path): + config = OmegaConf.load(config_path) + state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] + keys = list(state_dict.keys()) + + # extract state_dict for VQVAE + first_stage_dict = {} + first_stage_key = "first_stage_model." + for key in keys: + if key.startswith(first_stage_key): + first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key] + + # extract state_dict for UNetLDM + unet_state_dict = {} + unet_key = "model.diffusion_model." + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = state_dict[key] + + vqvae_init_args = config.model.params.first_stage_config.params + unet_init_args = config.model.params.unet_config.params + + vqvae = VQModel(**vqvae_init_args).eval() + vqvae.load_state_dict(first_stage_dict) + + unet = UNetLDMModel(**unet_init_args).eval() + unet.load_state_dict(unet_state_dict) + + noise_scheduler = DDIMScheduler( + timesteps=config.model.params.timesteps, + beta_schedule="scaled_linear", + beta_start=config.model.params.linear_start, + beta_end=config.model.params.linear_end, + clip_sample=False, + ) + + pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler) + pipeline.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--config_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + args = parser.parse_args() + + convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path) + diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py b/src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py deleted file mode 100644 index 3c512fba9a..0000000000 --- a/src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py +++ /dev/null @@ -1,13 +0,0 @@ -import argparse - -import torch - -from diffusers import UNetLDMModel, VQModel - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--checkpoint_path", type=str, required=True) - parser.add_argument("--config_path", type=str, required=True) - parser.add_argument("--output_path", type=str, required=True) - args = parser.parse_args() - From f26d3011c77e6df56ab244a41316bd8c8bc1cc30 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 17:19:26 +0200 Subject: [PATCH 09/10] fix ldm uncond pipeline --- .../pipeline_latent_diffusion_uncond.py | 5 +---- tests/test_modeling_utils.py | 4 +++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 873683b06d..f930e37709 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): # 4. 
set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + variance - # scale and decode image with vae - image = 1 / 0.18215 * image + # decode image with vae image = self.vqvae.decode(image) - image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) - return image diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 94f88a6a04..420aea2ac3 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase): image_slice = image[0, -1, -3:, -3:].cpu() assert image.shape == (1, 3, 256, 256) - expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293]) + expected_slice = torch.tensor( + [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106] + ) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 def test_module_from_pipeline(self): From ab946575b1050e67e2e6b4fdda237faa2dc342f5 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 17:44:38 +0200 Subject: [PATCH 10/10] add conversion script for BDDMPipeline --- scripts/conversion_bddm.py | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 scripts/conversion_bddm.py diff --git a/scripts/conversion_bddm.py b/scripts/conversion_bddm.py new file mode 100644 index 0000000000..205ff08e98 --- /dev/null +++ b/scripts/conversion_bddm.py @@ -0,0 +1,40 @@ + +import argparse +import torch + +from diffusers.pipelines.bddm import DiffWave, BDDMPipeline +from diffusers import DDPMScheduler + + +def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path): + sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"] + noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu") + + model = DiffWave() + model.load_state_dict(sd, strict=False) + + ts, _, betas, _ = noise_scheduler_sd + ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist()) + + noise_scheduler = DDPMScheduler( + timesteps=12, + trained_betas=betas, + timestep_values=ts, + clip_sample=False, + tensor_format="np", + ) + + pipeline = BDDMPipeline(model, noise_scheduler) + pipeline.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + args = parser.parse_args() + + convert_bddm_orginal(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path) + +
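One design choice in patch 02 worth spelling out: Upsample and Downsample now take a name argument so the convolution can be registered either as conv (or op) or as Conv2d_0. The latter mirrors the attribute used by FirUpsample and FirDownsample, which keeps parameter names consistent between the fir=True and fir=False paths of NCSNpp. A small sketch of the effect on state-dict keys, assuming the classes import from diffusers.models.resnet as they do in unet_sde_score_estimation.py:

    from diffusers.models.resnet import Downsample, Upsample

    up = Upsample(64, use_conv=True, name="Conv2d_0")
    down = Downsample(64, use_conv=True, padding=0, name="Conv2d_0")

    # The conv weights are registered under "Conv2d_0.*" rather than "conv.*"/"op.*",
    # mirroring FirUpsample/FirDownsample, so checkpoints keep the same key names
    # whichever up/downsampling variant is instantiated.
    print(list(up.state_dict().keys()))    # ['Conv2d_0.weight', 'Conv2d_0.bias']
    print(list(down.state_dict().keys()))  # ['Conv2d_0.weight', 'Conv2d_0.bias']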
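Usage sketch for the two conversion scripts added in patches 08 and 10. Both are plain argparse CLIs, and both end with pipeline.save_pretrained(output_path), so the output directory is a regular diffusers pipeline folder. The checkpoint and config paths below are placeholders, not files referenced anywhere in this series:

    python scripts/conversion_ldm_uncond.py \
        --checkpoint_path /path/to/ldm_original.ckpt \
        --config_path /path/to/ldm_config.yaml \
        --output_path ./ldm-uncond-converted

    python scripts/conversion_bddm.py \
        --checkpoint_path /path/to/diffwave_checkpoint.pkl \
        --noise_scheduler_checkpoint_path /path/to/noise_schedule.pkl \
        --output_path ./bddm-converted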
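A converted LatentDiffusionUncondPipeline can then be reloaded through the standard from_pretrained machinery. A minimal sanity check, assuming from_pretrained round-trips the folder written by save_pretrained and using the placeholder path from the invocation above:

    from diffusers import LatentDiffusionUncondPipeline

    # Placeholder path: the folder written by scripts/conversion_ldm_uncond.py above.
    pipeline = LatentDiffusionUncondPipeline.from_pretrained("./ldm-uncond-converted")

    # Patch 09 calls self.vqvae.decode(image) inside the pipeline, so a converted
    # checkpoint should expose the VQModel under the vqvae attribute.
    print(type(pipeline.vqvae))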