From 099d3eab4943dc50f36d9b75172f34bfa22df40c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 16:53:41 +0200 Subject: [PATCH] add conversion script for LatentDiffusionUncondPipeline --- scripts/conversion_ldm_uncond.py | 56 +++++++++++++++++++ .../convert_ldm_to_diffusers.py | 13 ----- 2 files changed, 56 insertions(+), 13 deletions(-) create mode 100644 scripts/conversion_ldm_uncond.py delete mode 100644 src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py diff --git a/scripts/conversion_ldm_uncond.py b/scripts/conversion_ldm_uncond.py new file mode 100644 index 0000000000..dd3fc7a9e0 --- /dev/null +++ b/scripts/conversion_ldm_uncond.py @@ -0,0 +1,56 @@ +import argparse + +import OmegaConf +import torch + +from diffusers import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler + +def convert_ldm_original(checkpoint_path, config_path, output_path): + config = OmegaConf.load(config_path) + state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] + keys = list(state_dict.keys()) + + # extract state_dict for VQVAE + first_stage_dict = {} + first_stage_key = "first_stage_model." + for key in keys: + if key.startswith(first_stage_key): + first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key] + + # extract state_dict for UNetLDM + unet_state_dict = {} + unet_key = "model.diffusion_model." + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = state_dict[key] + + vqvae_init_args = config.model.params.first_stage_config.params + unet_init_args = config.model.params.unet_config.params + + vqvae = VQModel(**vqvae_init_args).eval() + vqvae.load_state_dict(first_stage_dict) + + unet = UNetLDMModel(**unet_init_args).eval() + unet.load_state_dict(unet_state_dict) + + noise_scheduler = DDIMScheduler( + timesteps=config.model.params.timesteps, + beta_schedule="scaled_linear", + beta_start=config.model.params.linear_start, + beta_end=config.model.params.linear_end, + clip_sample=False, + ) + + pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler) + pipeline.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--config_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + args = parser.parse_args() + + convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path) + diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py b/src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py deleted file mode 100644 index 3c512fba9a..0000000000 --- a/src/diffusers/pipelines/latent_diffusion_uncond/convert_ldm_to_diffusers.py +++ /dev/null @@ -1,13 +0,0 @@ -import argparse - -import torch - -from diffusers import UNetLDMModel, VQModel - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--checkpoint_path", type=str, required=True) - parser.add_argument("--config_path", type=str, required=True) - parser.add_argument("--output_path", type=str, required=True) - args = parser.parse_args() -