From 48269070d23ad8a4c6f31bc6847c358aac182ad1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Jun 2022 13:40:08 +0000 Subject: [PATCH 01/35] more fixes --- README.md | 8 ++-- src/diffusers/__init__.py | 6 +-- src/diffusers/pipeline_utils.py | 26 ++-------- src/diffusers/pipelines/README.md | 2 +- src/diffusers/pipelines/__init__.py | 14 +++--- src/diffusers/pipelines/pipeline_bddm.py | 2 +- src/diffusers/pipelines/pipeline_ddim.py | 2 +- src/diffusers/pipelines/pipeline_ddpm.py | 2 +- src/diffusers/pipelines/pipeline_glide.py | 2 +- src/diffusers/pipelines/pipeline_grad_tts.py | 13 +++-- .../pipelines/pipeline_latent_diffusion.py | 2 +- src/diffusers/pipelines/pipeline_pndm.py | 2 +- tests/test_modeling_utils.py | 47 ++++++++++--------- 13 files changed, 59 insertions(+), 69 deletions(-) diff --git a/README.md b/README.md index f6889baf92..32dc7c8229 100644 --- a/README.md +++ b/README.md @@ -249,24 +249,24 @@ image_pil = PIL.Image.fromarray(image_processed[0]) image_pil.save("test.png") ``` -#### **Text to speech with GradTTS and BDDM** +#### **Text to speech with GradTTS and BDDMPipeline** ```python import torch -from diffusers import BDDM, GradTTS +from diffusers import BDDMPipeline, GradTTS torch_device = "cuda" # load grad tts and bddm pipelines grad_tts = GradTTS.from_pretrained("fusing/grad-tts-libri-tts") -bddm = BDDM.from_pretrained("fusing/diffwave-vocoder-ljspeech") +bddm = BDDMPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech") text = "Hello world, I missed you so much." # generate mel spectograms using text mel_spec = grad_tts(text, torch_device=torch_device) -# generate the speech by passing mel spectograms to BDDM pipeline +# generate the speech by passing mel spectograms to BDDMPipeline pipeline generator = torch.manual_seed(42) audio = bddm(mel_spec, generator, torch_device=torch_device) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index efb89e8597..aaca3d347b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -11,19 +11,19 @@ from .models.unet import UNetModel from .models.unet_ldm import UNetLDMModel from .models.unet_rl import TemporalUNet from .pipeline_utils import DiffusionPipeline -from .pipelines import BDDM, DDIM, DDPM, PNDM +from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin if is_transformers_available(): from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel from .models.unet_grad_tts import UNetGradTTSModel - from .pipelines import Glide, LatentDiffusion + from .pipelines import GlidePipeline, LatentDiffusionPipeline else: from .utils.dummy_transformers_objects import * if is_transformers_available() and is_inflect_available() and is_unidecode_available(): - from .pipelines import GradTTS + from .pipelines import GradTTSPipeline else: from .utils.dummy_transformers_and_inflect_and_unidecode_objects import * diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index d8a2644dc9..339ebb074a 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -21,7 +21,6 @@ from typing import Optional, Union from huggingface_hub import snapshot_download from .configuration_utils import ConfigMixin -from .dynamic_modules_utils import get_class_from_dynamic_module from .utils import DIFFUSERS_CACHE, logging @@ -81,9 +80,6 @@ class DiffusionPipeline(ConfigMixin): # set models 
setattr(self, name, module) - register_dict = {"_module": self.__module__.split(".")[-1]} - self.register_to_config(**register_dict) - def save_pretrained(self, save_directory: Union[str, os.PathLike]): self.save_config(save_directory) @@ -139,11 +135,7 @@ class DiffusionPipeline(ConfigMixin): config_dict = cls.get_config_dict(cached_folder) - # 2. Get class name and module candidates to load custom models - module_candidate_name = config_dict["_module"] - module_candidate = module_candidate_name + ".py" - - # 3. Load the pipeline class, if using custom module then load it from the hub + # 2. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if cls != DiffusionPipeline: pipeline_class = cls @@ -151,11 +143,6 @@ class DiffusionPipeline(ConfigMixin): diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) - # (TODO - we should allow to load custom pipelines - # else we need to load the correct module from the Hub - # module = module_candidate - # pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) - init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} @@ -163,7 +150,7 @@ class DiffusionPipeline(ConfigMixin): # import it here to avoid circular import from diffusers import pipelines - # 4. Load each module in the pipeline + # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): is_pipeline_module = hasattr(pipelines, library_name) # if the model is in a pipeline module, then we load it from the pipeline @@ -171,14 +158,7 @@ class DiffusionPipeline(ConfigMixin): pipeline_module = getattr(pipelines, library_name) class_obj = getattr(pipeline_module, class_name) importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} - elif library_name == module_candidate_name: - # if the model is not in diffusers or transformers, we need to load it from the hub - # assumes that it's a subclass of ModelMixin - class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder) - # since it's not from a library, we need to check class candidates for all importable classes - importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} + class_candidates = {c: class_obj for c in importable_classes.keys()} else: # else we just import it from the library. library = importlib.import_module(library_name) diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index 61e653a80f..c0558d35b9 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -15,5 +15,5 @@ TODO(Patrick, Anton, Suraj) - PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py). - Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py). - Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py). 
-- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py). +- BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py). - Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py). diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7ba126b03b..d26c5fc8a7 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,14 +1,14 @@ from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available -from .pipeline_bddm import BDDM -from .pipeline_ddim import DDIM -from .pipeline_ddpm import DDPM -from .pipeline_pndm import PNDM +from .pipeline_bddm import BDDMPipeline +from .pipeline_ddim import DDIMPipeline +from .pipeline_ddpm import DDPMPipeline +from .pipeline_pndm import PNDMPipeline if is_transformers_available(): - from .pipeline_glide import Glide - from .pipeline_latent_diffusion import LatentDiffusion + from .pipeline_glide import GlidePipeline + from .pipeline_latent_diffusion import LatentDiffusionPipeline if is_transformers_available() and is_unidecode_available() and is_inflect_available(): - from .pipeline_grad_tts import GradTTS + from .pipeline_grad_tts import GradTTSPipeline diff --git a/src/diffusers/pipelines/pipeline_bddm.py b/src/diffusers/pipelines/pipeline_bddm.py index 3ca79c3dee..8b24cb9ceb 100644 --- a/src/diffusers/pipelines/pipeline_bddm.py +++ b/src/diffusers/pipelines/pipeline_bddm.py @@ -271,7 +271,7 @@ class DiffWave(ModelMixin, ConfigMixin): return self.final_conv(x) -class BDDM(DiffusionPipeline): +class BDDMPipeline(DiffusionPipeline): def __init__(self, diffwave, noise_scheduler): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") diff --git a/src/diffusers/pipelines/pipeline_ddim.py b/src/diffusers/pipelines/pipeline_ddim.py index 272d3edb6b..8da24dbf8f 100644 --- a/src/diffusers/pipelines/pipeline_ddim.py +++ b/src/diffusers/pipelines/pipeline_ddim.py @@ -21,7 +21,7 @@ import tqdm from ..pipeline_utils import DiffusionPipeline -class DDIM(DiffusionPipeline): +class DDIMPipeline(DiffusionPipeline): def __init__(self, unet, noise_scheduler): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") diff --git a/src/diffusers/pipelines/pipeline_ddpm.py b/src/diffusers/pipelines/pipeline_ddpm.py index ebcce77337..9cf83bfb75 100644 --- a/src/diffusers/pipelines/pipeline_ddpm.py +++ b/src/diffusers/pipelines/pipeline_ddpm.py @@ -21,7 +21,7 @@ import tqdm from ..pipeline_utils import DiffusionPipeline -class DDPM(DiffusionPipeline): +class DDPMPipeline(DiffusionPipeline): def __init__(self, unet, noise_scheduler): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index 0046055349..8680b7542a 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -711,7 +711,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): return res + torch.zeros(broadcast_shape, device=timesteps.device) -class Glide(DiffusionPipeline): +class GlidePipeline(DiffusionPipeline): def __init__( self, text_unet: GlideTextToImageUNetModel, diff --git 
a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 4201124923..51c861a262 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -420,7 +420,7 @@ class TextEncoder(ModelMixin, ConfigMixin): return mu, logw, x_mask -class GradTTS(DiffusionPipeline): +class GradTTSPipeline(DiffusionPipeline): def __init__(self, unet, text_encoder, noise_scheduler, tokenizer): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") @@ -430,7 +430,14 @@ class GradTTS(DiffusionPipeline): @torch.no_grad() def __call__( - self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None + self, + text, + num_inference_steps=50, + temperature=1.3, + length_scale=0.91, + speaker_id=15, + torch_device=None, + generator=None, ): if torch_device is None: torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -464,7 +471,7 @@ class GradTTS(DiffusionPipeline): mu_y = mu_y.transpose(1, 2) # Sample latent representation from terminal distribution N(mu_y, I) - z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature + z = mu_y + torch.randn(mu_y.shape, device=mu_y.device, generator=generator) / temperature xt = z * y_mask h = 1.0 / num_inference_steps diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py index cd7f653bf4..7d386765d4 100644 --- a/src/diffusers/pipelines/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py @@ -860,7 +860,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): return dec, posterior -class LatentDiffusion(DiffusionPipeline): +class LatentDiffusionPipeline(DiffusionPipeline): def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py index a19f933ed1..5fd8a98483 100644 --- a/src/diffusers/pipelines/pipeline_pndm.py +++ b/src/diffusers/pipelines/pipeline_pndm.py @@ -21,7 +21,7 @@ import tqdm from ..pipeline_utils import DiffusionPipeline -class PNDM(DiffusionPipeline): +class PNDMPipeline(DiffusionPipeline): def __init__(self, unet, noise_scheduler): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 372435de9d..720e68741f 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -22,17 +22,17 @@ import numpy as np import torch from diffusers import ( - BDDM, - DDIM, - DDPM, - Glide, - PNDM, + BDDMPipeline, + DDIMPipeline, DDIMScheduler, + DDPMPipeline, DDPMScheduler, + GlidePipeline, GlideSuperResUNetModel, GlideTextToImageUNetModel, - GradTTS, - LatentDiffusion, + GradTTSPipeline, + LatentDiffusionPipeline, + PNDMPipeline, PNDMScheduler, UNetGradTTSModel, UNetLDMModel, @@ -583,11 +583,11 @@ class PipelineTesterMixin(unittest.TestCase): model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) schedular = DDPMScheduler(timesteps=10) - ddpm = DDPM(model, schedular) + ddpm = DDPMPipeline(model, schedular) with tempfile.TemporaryDirectory() as tmpdirname: ddpm.save_pretrained(tmpdirname) - new_ddpm = DDPM.from_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) generator = torch.manual_seed(0) @@ -601,7 +601,7 @@ class 
PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_hub(self): model_path = "fusing/ddpm-cifar10" - ddpm = DDPM.from_pretrained(model_path) + ddpm = DDPMPipeline.from_pretrained(model_path) ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) ddpm.noise_scheduler.num_timesteps = 10 @@ -624,7 +624,7 @@ class PipelineTesterMixin(unittest.TestCase): noise_scheduler = DDPMScheduler.from_config(model_id) noise_scheduler = noise_scheduler.set_format("pt") - ddpm = DDPM(unet=unet, noise_scheduler=noise_scheduler) + ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler) image = ddpm(generator=generator) image_slice = image[0, -1, -3:, -3:].cpu() @@ -641,7 +641,7 @@ class PipelineTesterMixin(unittest.TestCase): unet = UNetModel.from_pretrained(model_id) noise_scheduler = DDIMScheduler(tensor_format="pt") - ddim = DDIM(unet=unet, noise_scheduler=noise_scheduler) + ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) image = ddim(generator=generator, eta=0.0) image_slice = image[0, -1, -3:, -3:].cpu() @@ -660,7 +660,7 @@ class PipelineTesterMixin(unittest.TestCase): unet = UNetModel.from_pretrained(model_id) noise_scheduler = PNDMScheduler(tensor_format="pt") - pndm = PNDM(unet=unet, noise_scheduler=noise_scheduler) + pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler) image = pndm(generator=generator) image_slice = image[0, -1, -3:, -3:].cpu() @@ -674,7 +674,7 @@ class PipelineTesterMixin(unittest.TestCase): @slow def test_ldm_text2img(self): model_id = "fusing/latent-diffusion-text2im-large" - ldm = LatentDiffusion.from_pretrained(model_id) + ldm = LatentDiffusionPipeline.from_pretrained(model_id) prompt = "A painting of a squirrel eating a burger" generator = torch.manual_seed(0) @@ -689,7 +689,7 @@ class PipelineTesterMixin(unittest.TestCase): @slow def test_glide_text2img(self): model_id = "fusing/glide-base" - glide = Glide.from_pretrained(model_id) + glide = GlidePipeline.from_pretrained(model_id) prompt = "a pencil sketch of a corgi" generator = torch.manual_seed(0) @@ -704,22 +704,25 @@ class PipelineTesterMixin(unittest.TestCase): @slow def test_grad_tts(self): model_id = "fusing/grad-tts-libri-tts" - grad_tts = GradTTS.from_pretrained(model_id) + grad_tts = GradTTSPipeline.from_pretrained(model_id) text = "Hello world, I missed you so much." 
+ generator = torch.manual_seed(0) # generate mel spectograms using text - mel_spec = grad_tts(text) + mel_spec = grad_tts(text, generator=generator) - assert mel_spec.shape == (1, 256, 256, 3) - expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784]) - assert (mel_spec.flatten() - expected_slice).abs().max() < 1e-2 + assert mel_spec.shape == (1, 80, 143) + expected_slice = torch.tensor( + [-6.6119, -6.5963, -6.2776, -6.7496, -6.7096, -6.5131, -6.4643, -6.4817, -6.7185] + ) + assert (mel_spec[0, :3, :3].flatten() - expected_slice).abs().max() < 1e-2 def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) noise_scheduler = DDPMScheduler(timesteps=12) - bddm = BDDM(model, noise_scheduler) + bddm = BDDMPipeline(model, noise_scheduler) # check if the library name for the diffwave moduel is set to pipeline module self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm") @@ -727,6 +730,6 @@ class PipelineTesterMixin(unittest.TestCase): # check if we can save and load the pipeline with tempfile.TemporaryDirectory() as tmpdirname: bddm.save_pretrained(tmpdirname) - _ = BDDM.from_pretrained(tmpdirname) + _ = BDDMPipeline.from_pretrained(tmpdirname) # check if the same works using the DifusionPipeline class _ = DiffusionPipeline.from_pretrained(tmpdirname) From 40e28e8bf4165c1167148fb825affd57c53b00ea Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Jun 2022 13:42:09 +0000 Subject: [PATCH 02/35] only remove module if necessary --- src/diffusers/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 339ebb074a..d73b8d8fb3 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -86,7 +86,7 @@ class DiffusionPipeline(ConfigMixin): model_index_dict = dict(self.config) model_index_dict.pop("_class_name") model_index_dict.pop("_diffusers_version") - model_index_dict.pop("_module") + model_index_dict.pop("_module", None) for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) From 3a17775454b80d2b0bceb0de7ac6b444ff288c75 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Wed, 22 Jun 2022 17:26:07 +0200 Subject: [PATCH 03/35] TODO: Add FID and KID metrics --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 32dc7c8229..7cb20b0e0e 100644 --- a/README.md +++ b/README.md @@ -288,3 +288,4 @@ wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy()) - [ ] Add more vision models - [ ] Add more speech models - [ ] Add RL model +- [ ] Add FID and KID metrics From 6e456b7a7afa72543cad6503c91d31c6cb793a3a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Jun 2022 18:38:32 +0200 Subject: [PATCH 04/35] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7cb20b0e0e..6c2c9799c2 100644 --- a/README.md +++ b/README.md @@ -253,12 +253,12 @@ image_pil.save("test.png") ```python import torch -from diffusers import BDDMPipeline, GradTTS +from diffusers import BDDMPipeline, GradTTSPipeline torch_device = "cuda" # load grad tts and bddm pipelines -grad_tts = GradTTS.from_pretrained("fusing/grad-tts-libri-tts") +grad_tts = GradTTSPipeline.from_pretrained("fusing/grad-tts-libri-tts") bddm = BDDMPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech") text = "Hello world, I missed you so much." 
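A note before the next patch: it reworks how the Grad-TTS latent noise is drawn, sampling the tensor on the CPU with the explicit `generator` and only then moving it to the target device (it also drops the `/ temperature` scaling on that draw). The commit message only says "correct diffusion test", so the motivation is inferred here: `torch.Generator` objects are device-bound, so a CPU generator cannot be combined with `torch.randn(..., device="cuda")`, and sampling on the CPU first keeps a fixed seed reproducible whether the pipeline runs on CPU or GPU. A minimal sketch of the pattern, with an illustrative shape:

```python
import torch

# Illustrative values only; the pipeline uses the mel-spectrogram shape.
shape = (1, 80, 143)
device = "cuda" if torch.cuda.is_available() else "cpu"

generator = torch.manual_seed(0)  # seeds and returns the default (CPU) generator

# Drawing the noise on the CPU and then moving it yields identical samples
# for a given seed on CPU and GPU runs.
noise = torch.randn(shape, generator=generator).to(device)
```
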
From 0244e2af4ca354b4fed5a7258cebdd3f28a606c0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Jun 2022 18:41:14 +0200 Subject: [PATCH 05/35] correct diffusion test --- src/diffusers/pipelines/pipeline_grad_tts.py | 2 +- tests/test_modeling_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 51c861a262..3ad6bc8146 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -471,7 +471,7 @@ class GradTTSPipeline(DiffusionPipeline): mu_y = mu_y.transpose(1, 2) # Sample latent representation from terminal distribution N(mu_y, I) - z = mu_y + torch.randn(mu_y.shape, device=mu_y.device, generator=generator) / temperature + z = mu_y + torch.randn(mu_y.shape, generator=generator).to(mu_y.device) xt = z * y_mask h = 1.0 / num_inference_steps diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 720e68741f..f75bce88a9 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -714,9 +714,9 @@ class PipelineTesterMixin(unittest.TestCase): assert mel_spec.shape == (1, 80, 143) expected_slice = torch.tensor( - [-6.6119, -6.5963, -6.2776, -6.7496, -6.7096, -6.5131, -6.4643, -6.4817, -6.7185] + [-6.7584, -6.8347, -6.3293, -6.6437, -6.7233, -6.4684, -6.1187, -6.3172, -6.6890] ) - assert (mel_spec[0, :3, :3].flatten() - expected_slice).abs().max() < 1e-2 + assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2 def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) From 7b4e049eb00154219c025d20e2273f766c3bfc5f Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 22 Jun 2022 14:16:53 -0400 Subject: [PATCH 06/35] adding properties, formatting --- src/diffusers/models/unet_rl.py | 54 ++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 55654dc62e..4fdffd33a0 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -5,7 +5,6 @@ import math import torch import torch.nn as nn - try: import einops from einops.layers.torch import Rearrange @@ -13,7 +12,6 @@ except: print("Einops is not installed") pass - from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin @@ -106,15 +104,22 @@ class ResidualTemporalBlock(nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): def __init__( - self, - horizon, - transition_dim, - cond_dim, - dim=32, - dim_mults=(1, 2, 4, 8), + self, + training_horizon, + transition_dim, + cond_dim, + predict_epsilon=False, + clip_denoised=True, + dim=32, + dim_mults=(1, 2, 4, 8), ): super().__init__() + self.transition_dim = transition_dim + self.cond_dim = cond_dim + self.predict_epsilon = predict_epsilon + self.clip_denoised = clip_denoised + dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # print(f'[ models/temporal ] Channel dimensions: {in_out}') @@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): self.downs.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon), - ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon), + ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon), + ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, 
horizon=training_horizon), Downsample1d(dim_out) if not is_last else nn.Identity(), ] ) ) if not is_last: - horizon = horizon // 2 + training_horizon = training_horizon // 2 mid_dim = dims[-1] - self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) - self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) + self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon) + self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) @@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): self.ups.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon), - ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon), + ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon), + ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon), Upsample1d(dim_in) if not is_last else nn.Identity(), ] ) ) if not is_last: - horizon = horizon * 2 + training_horizon = training_horizon * 2 self.final_conv = nn.Sequential( Conv1dBlock(dim, dim, kernel_size=5), @@ -206,14 +211,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalValue(nn.Module): def __init__( - self, - horizon, - transition_dim, - cond_dim, - dim=32, - time_dim=None, - out_dim=1, - dim_mults=(1, 2, 4, 8), + self, + horizon, + transition_dim, + cond_dim, + dim=32, + time_dim=None, + out_dim=1, + dim_mults=(1, 2, 4, 8), ): super().__init__() @@ -232,7 +237,6 @@ class TemporalValue(nn.Module): print(in_out) for dim_in, dim_out in in_out: - self.blocks.append( nn.ModuleList( [ From f941fc9917804941456efbfa5fff1024346d329d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Jun 2022 23:15:57 +0200 Subject: [PATCH 07/35] refactor tts sampler a bit --- src/diffusers/pipelines/pipeline_glide.py | 1 + src/diffusers/pipelines/pipeline_grad_tts.py | 6 ++- .../schedulers/scheduling_grad_tts.py | 37 +++++++++++-------- tests/test_modeling_utils.py | 3 ++ 4 files changed, 30 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index 51be34efde..8680b7542a 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -694,6 +694,7 @@ class CLIPTextModel(CLIPPreTrainedModel): # END OF THE CLIP MODEL COPY-PASTE ##################### + def _extract_into_tensor(arr, timesteps, broadcast_shape): """ Extract values from a 1-D numpy array for a batch of indices. 
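A note on the Grad-TTS changes below: the new scheduler `step` is the old update rearranged, not a different sampler. Previously the pipeline computed

    xt = xt - 0.5 * (mu - xt - residual) * beta_t * h,    with h = 1 / num_inference_steps

which regroups to

    xt = xt + (residual - mu + xt) * beta_t / (2 * num_inference_steps)

The pipeline now pre-combines `residual - mu_y + xt` and the scheduler applies `sample + residual * (beta_t / num_inference_steps) / 2`, which is the same expression. The reversed index `t_new = num_inference_steps - t - 1` likewise reproduces the old descending time `1 - (t + 0.5) * h`, because the new scheduler evaluates its betas at `(t + 0.5) / num_inference_steps`. The only numerical change is the new `beta_start`/`beta_end` defaults (0.05/20 instead of 0.0001/0.02).
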
diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 3ad6bc8146..743104e658 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -475,13 +475,15 @@ class GradTTSPipeline(DiffusionPipeline): xt = z * y_mask h = 1.0 / num_inference_steps + # (Patrick: TODO) for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps): + t_new = num_inference_steps - t - 1 t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) - time = t.unsqueeze(-1).unsqueeze(-1) residual = self.unet(xt, t, mu_y, y_mask, speaker_id) - xt = self.noise_scheduler.step(xt, residual, mu_y, h, time) + scheduler_residual = residual - mu_y + xt + xt = self.noise_scheduler.step(scheduler_residual, xt, t_new, num_inference_steps) xt = xt * y_mask return xt[:, :, :y_max_length] diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py index 94b5f2ac55..4dc6638de3 100644 --- a/src/diffusers/schedulers/scheduling_grad_tts.py +++ b/src/diffusers/schedulers/scheduling_grad_tts.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin @@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin class GradTTSScheduler(SchedulerMixin, ConfigMixin): def __init__( self, - timesteps=1000, - beta_start=0.0001, - beta_end=0.02, + beta_start=0.05, + beta_end=20, tensor_format="np", ): super().__init__() self.register_to_config( - timesteps=timesteps, beta_start=beta_start, beta_end=beta_end, ) self.set_format(tensor_format=tensor_format) + self.betas = None - def sample_noise(self, timestep): - noise = self.beta_start + (self.beta_end - self.beta_start) * timestep - return noise + def get_timesteps(self, num_inference_steps): + return np.array([(t + 0.5) / num_inference_steps for t in range(num_inference_steps)]) - def step(self, xt, residual, mu, h, timestep): - noise_t = self.sample_noise(timestep) - dxt = 0.5 * (mu - xt - residual) - dxt = dxt * noise_t * h - xt = xt - dxt - return xt + def set_betas(self, num_inference_steps): + timesteps = self.get_timesteps(num_inference_steps) + self.betas = np.array([self.beta_start + (self.beta_end - self.beta_start) * t for t in timesteps]) - def __len__(self): - return len(self.config.timesteps) + def step(self, residual, sample, t, num_inference_steps): + # This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix) + if self.betas is None: + self.set_betas(num_inference_steps) + + beta_t = self.betas[t] + beta_t_deriv = beta_t / num_inference_steps + + sample_deriv = residual * beta_t_deriv / 2 + + sample = sample + sample_deriv + return sample diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index f75bce88a9..db4ed6eb02 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -31,6 +31,7 @@ from diffusers import ( GlideSuperResUNetModel, GlideTextToImageUNetModel, GradTTSPipeline, + GradTTSScheduler, LatentDiffusionPipeline, PNDMPipeline, PNDMScheduler, @@ -705,6 +706,8 @@ class PipelineTesterMixin(unittest.TestCase): def test_grad_tts(self): model_id = "fusing/grad-tts-libri-tts" grad_tts = GradTTSPipeline.from_pretrained(model_id) + noise_scheduler = GradTTSScheduler() + grad_tts.noise_scheduler = noise_scheduler text = "Hello 
world, I missed you so much." generator = torch.manual_seed(0) From c3c1bdf8e244de5424a2fe52ed5c3befcf5366b7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 24 Jun 2022 10:44:52 +0200 Subject: [PATCH 08/35] fixed typo in comment --- src/diffusers/schedulers/scheduling_ddpm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 331fad0f1e..5dea0b22b3 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -92,9 +92,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one - # For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variane to pred_sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] if variance_type is None: From ac796924dff7241d9b516ea27faaa7b2f12434fd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Jun 2022 15:55:26 +0000 Subject: [PATCH 09/35] add score estimation model --- src/diffusers/__init__.py | 4 +- src/diffusers/models/__init__.py | 1 + src/diffusers/models/unet_rl.py | 33 +- .../models/unet_sde_score_estimation.py | 1051 +++++++++++++++++ 4 files changed, 1070 insertions(+), 19 deletions(-) create mode 100644 src/diffusers/models/unet_sde_score_estimation.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index aaca3d347b..ba6df51070 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,9 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode __version__ = "0.0.4" from .modeling_utils import ModelMixin -from .models.unet import UNetModel -from .models.unet_ldm import UNetLDMModel -from .models.unet_rl import TemporalUNet +from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3f0c78b3c6..71e321e111 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide from .unet_grad_tts import UNetGradTTSModel from .unet_ldm import UNetLDMModel from .unet_rl import TemporalUNet +from .unet_sde_score_estimation import NCSNpp diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 4fdffd33a0..28fea5753c 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -5,6 +5,7 @@ import math import torch import torch.nn as nn + try: import einops from einops.layers.torch import Rearrange @@ -104,14 +105,14 @@ class ResidualTemporalBlock(nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): def __init__( - self, - training_horizon, - transition_dim, - cond_dim, - predict_epsilon=False, - clip_denoised=True, - dim=32, - dim_mults=(1, 2, 4, 8), + self, 
+ training_horizon, + transition_dim, + cond_dim, + predict_epsilon=False, + clip_denoised=True, + dim=32, + dim_mults=(1, 2, 4, 8), ): super().__init__() @@ -211,14 +212,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalValue(nn.Module): def __init__( - self, - horizon, - transition_dim, - cond_dim, - dim=32, - time_dim=None, - out_dim=1, - dim_mults=(1, 2, 4, 8), + self, + horizon, + transition_dim, + cond_dim, + dim=32, + time_dim=None, + out_dim=1, + dim_mults=(1, 2, 4, 8), ): super().__init__() diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py new file mode 100644 index 0000000000..26b4419ea2 --- /dev/null +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -0,0 +1,1051 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + +# helpers functions + + +import functools +import math +import string + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) + + +# Function ported from StyleGAN2 +def get_weight(module, shape, weight_var="weight", kernel_init=None): + """Get/create weight tensor for a convolution or fully-connected layer.""" + + return module.param(weight_var, kernel_init, shape) + + +class Conv2d(nn.Module): + """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" + + def __init__( + self, + in_ch, + out_ch, + kernel, + up=False, + down=False, + resample_kernel=(1, 3, 3, 1), + use_bias=True, + kernel_init=None, + ): + super().__init__() + assert not (up and down) + 
assert kernel >= 1 and kernel % 2 == 1 + self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) + if kernel_init is not None: + self.weight.data = kernel_init(self.weight.data.shape) + if use_bias: + self.bias = nn.Parameter(torch.zeros(out_ch)) + + self.up = up + self.down = down + self.resample_kernel = resample_kernel + self.kernel = kernel + self.use_bias = use_bias + + def forward(self, x): + if self.up: + x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) + elif self.down: + x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) + else: + x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) + + if self.use_bias: + x = x + self.bias.reshape(1, -1, 1, 1) + + return x + + +def naive_upsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H, 1, W, 1)) + x = x.repeat(1, 1, 1, factor, 1, factor) + return torch.reshape(x, (-1, C, H * factor, W * factor)) + + +def naive_downsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) + return torch.mean(x, dim=(3, 5)) + + +def upsample_conv_2d(x, w, k=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. + + Padding is performed only once at the beginning, not between the + operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or + `[N, H * factor, W * factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Check weight shape. + assert len(w.shape) == 4 + convH = w.shape[2] + convW = w.shape[3] + inC = w.shape[1] + + assert convW == convH + + # Setup filter kernel. + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor**2)) + p = (k.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + + # Determine data dimensions. + stride = [1, 1, factor, factor] + output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) + output_padding = ( + output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, + output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = _shape(x, 1) // inC + + # Transpose weights. + w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) + w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) + w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) + + x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) + # Original TF code. 
+ # x = tf.nn.conv2d_transpose( + # x, + # w, + # output_shape=output_shape, + # strides=stride, + # padding='VALID', + # data_format=data_format) + # JAX equivalent + + return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + + +def conv_downsample_2d(x, w, k=None, factor=2, gain=1): + """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. + + Padding is performed only once at the beginning, not between the operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or + `[N, H // factor, W // factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + _outC, _inC, convH, convW = w.shape + assert convW == convH + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = (k.shape[0] - factor) + (convW - 1) + s = [factor, factor] + x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2)) + return F.conv2d(x, w, stride=s, padding=0) + + +def _setup_kernel(k): + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + assert k.ndim == 2 + assert k.shape[0] == k.shape[1] + return k + + +def _shape(x, dim): + return x.shape[dim] + + +def upsample_2d(x, k=None, factor=2, gain=1): + r"""Upsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and upsamples each image with the given filter. The filter is normalized so + that + if the input pixels are constant, they will be scaled by the specified + `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the upsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor**2)) + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) + + +def downsample_2d(x, k=None, factor=2, gain=1): + r"""Downsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and downsamples each image with the given filter. The filter is normalized + so that + if the input pixels are constant, they will be scaled by the specified + `gain`. 
+ Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the downsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + +def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0): + """1x1 convolution with DDPM initialization.""" + conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1): + """3x3 convolution with DDPM initialization.""" + conv = nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias + ) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +conv1x1 = ddpm_conv1x1 +conv3x3 = ddpm_conv3x3 + + +def _einsum(a, b, c, x, y): + einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) + return torch.einsum(einsum_str, x, y) + + +def contract_inner(x, y): + """tensordot(x, y, 1).""" + x_chars = list(string.ascii_lowercase[: len(x.shape)]) + y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)]) + y_chars[0] = x_chars[-1] # first axis of y and last of x get summed + out_chars = x_chars[:-1] + y_chars[1:] + return _einsum(x_chars, y_chars, out_chars, x, y) + + +class NIN(nn.Module): + def __init__(self, in_dim, num_units, init_scale=0.1): + super().__init__() + self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) + self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + y = contract_inner(x, self.W) + self.b + return y.permute(0, 3, 1, 2) + + +def get_act(config): + """Get activation functions from the config file.""" + + if config.model.nonlinearity.lower() == "elu": + return nn.ELU() + elif config.model.nonlinearity.lower() == "relu": + return nn.ReLU() + elif config.model.nonlinearity.lower() == "lrelu": + return nn.LeakyReLU(negative_slope=0.2) + elif config.model.nonlinearity.lower() == "swish": + return nn.SiLU() + else: + raise NotImplementedError("activation function does not exist!") + + +def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + half_dim = embedding_dim // 2 + # magic number 10000 is from transformers + emb = math.log(max_positions) / (half_dim - 1) + # emb = math.log(2.) 
/ (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) + # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] + # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = F.pad(emb, (0, 1), mode="constant") + assert emb.shape == (timesteps.shape[0], embedding_dim) + return emb + + +def default_init(scale=1.0): + """The same initialization used in DDPM.""" + scale = 1e-10 if scale == 0 else scale + return variance_scaling(scale, "fan_avg", "uniform") + + +def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): + """Ported from JAX.""" + + def _compute_fans(shape, in_axis=1, out_axis=0): + receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] + fan_in = shape[in_axis] * receptive_field_size + fan_out = shape[out_axis] * receptive_field_size + return fan_in, fan_out + + def init(shape, dtype=dtype, device=device): + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) + if mode == "fan_in": + denominator = fan_in + elif mode == "fan_out": + denominator = fan_out + elif mode == "fan_avg": + denominator = (fan_in + fan_out) / 2 + else: + raise ValueError("invalid mode for variance scaling initializer: {}".format(mode)) + variance = scale / denominator + if distribution == "normal": + return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) + elif distribution == "uniform": + return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance) + else: + raise ValueError("invalid distribution for variance scaling initializer") + + return init + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__(self, embedding_size=256, scale=1.0): + super().__init__() + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + +class Combine(nn.Module): + """Combine information from skip connections.""" + + def __init__(self, dim1, dim2, method="cat"): + super().__init__() + self.Conv_0 = conv1x1(dim1, dim2) + self.method = method + + def forward(self, x, y): + h = self.Conv_0(x) + if self.method == "cat": + return torch.cat([h, y], dim=1) + elif self.method == "sum": + return h + y + else: + raise ValueError(f"Method {self.method} not recognized.") + + +class AttnBlockpp(nn.Module): + """Channel-wise self-attention block. 
Modified from DDPM.""" + + def __init__(self, channels, skip_rescale=False, init_scale=0.0): + super().__init__() + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) + self.NIN_0 = NIN(channels, channels) + self.NIN_1 = NIN(channels, channels) + self.NIN_2 = NIN(channels, channels) + self.NIN_3 = NIN(channels, channels, init_scale=init_scale) + self.skip_rescale = skip_rescale + + def forward(self, x): + B, C, H, W = x.shape + h = self.GroupNorm_0(x) + q = self.NIN_0(h) + k = self.NIN_1(h) + v = self.NIN_2(h) + + w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) + w = torch.reshape(w, (B, H, W, H * W)) + w = F.softmax(w, dim=-1) + w = torch.reshape(w, (B, H, W, H, W)) + h = torch.einsum("bhwij,bcij->bchw", w, v) + h = self.NIN_3(h) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class Upsample(nn.Module): + def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_ch = out_ch if out_ch else in_ch + if not fir: + if with_conv: + self.Conv_0 = conv3x3(in_ch, out_ch) + else: + if with_conv: + self.Conv2d_0 = Conv2d( + in_ch, + out_ch, + kernel=3, + up=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) + self.fir = fir + self.with_conv = with_conv + self.fir_kernel = fir_kernel + self.out_ch = out_ch + + def forward(self, x): + B, C, H, W = x.shape + if not self.fir: + h = F.interpolate(x, (H * 2, W * 2), "nearest") + if self.with_conv: + h = self.Conv_0(h) + else: + if not self.with_conv: + h = upsample_2d(x, self.fir_kernel, factor=2) + else: + h = self.Conv2d_0(x) + + return h + + +class Downsample(nn.Module): + def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_ch = out_ch if out_ch else in_ch + if not fir: + if with_conv: + self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) + else: + if with_conv: + self.Conv2d_0 = Conv2d( + in_ch, + out_ch, + kernel=3, + down=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) + self.fir = fir + self.fir_kernel = fir_kernel + self.with_conv = with_conv + self.out_ch = out_ch + + def forward(self, x): + B, C, H, W = x.shape + if not self.fir: + if self.with_conv: + x = F.pad(x, (0, 1, 0, 1)) + x = self.Conv_0(x) + else: + x = F.avg_pool2d(x, 2, stride=2) + else: + if not self.with_conv: + x = downsample_2d(x, self.fir_kernel, factor=2) + else: + x = self.Conv2d_0(x) + + return x + + +class ResnetBlockDDPMpp(nn.Module): + """ResBlock adapted from DDPM.""" + + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + conv_shortcut=False, + dropout=0.1, + skip_rescale=False, + init_scale=0.0, + ): + super().__init__() + out_ch = out_ch if out_ch else in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) + nn.init.zeros_(self.Dense_0.bias) + self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch: + if conv_shortcut: + self.Conv_2 = conv3x3(in_ch, out_ch) + else: + self.NIN_0 = NIN(in_ch, out_ch) + + self.skip_rescale = skip_rescale + 
self.act = act + self.out_ch = out_ch + self.conv_shortcut = conv_shortcut + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + h = self.Conv_0(h) + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + if x.shape[1] != self.out_ch: + if self.conv_shortcut: + x = self.Conv_2(x) + else: + x = self.NIN_0(x) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class ResnetBlockBigGANpp(nn.Module): + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + up=False, + down=False, + dropout=0.1, + fir=False, + fir_kernel=(1, 3, 3, 1), + skip_rescale=True, + init_scale=0.0, + ): + super().__init__() + + out_ch = out_ch if out_ch else in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) + self.up = up + self.down = down + self.fir = fir + self.fir_kernel = fir_kernel + + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) + nn.init.zeros_(self.Dense_0.bias) + + self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch or up or down: + self.Conv_2 = conv1x1(in_ch, out_ch) + + self.skip_rescale = skip_rescale + self.act = act + self.in_ch = in_ch + self.out_ch = out_ch + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + + if self.up: + if self.fir: + h = upsample_2d(h, self.fir_kernel, factor=2) + x = upsample_2d(x, self.fir_kernel, factor=2) + else: + h = naive_upsample_2d(h, factor=2) + x = naive_upsample_2d(x, factor=2) + elif self.down: + if self.fir: + h = downsample_2d(h, self.fir_kernel, factor=2) + x = downsample_2d(x, self.fir_kernel, factor=2) + else: + h = naive_downsample_2d(h, factor=2) + x = naive_downsample_2d(x, factor=2) + + h = self.Conv_0(h) + # Add bias to each feature map conditioned on the time embedding + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + + if self.in_ch != self.out_ch or self.up or self.down: + x = self.Conv_2(x) + + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class NCSNpp(nn.Module): + """NCSN++ model""" + + def __init__(self, config): + super().__init__() + self.config = config + self.act = act = get_act(config) + # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) + + self.nf = nf = config.model.nf + ch_mult = config.model.ch_mult + self.num_res_blocks = num_res_blocks = config.model.num_res_blocks + self.attn_resolutions = attn_resolutions = config.model.attn_resolutions + dropout = config.model.dropout + resamp_with_conv = config.model.resamp_with_conv + self.num_resolutions = num_resolutions = len(ch_mult) + self.all_resolutions = all_resolutions = [config.data.image_size // (2**i) for i in range(num_resolutions)] + + self.conditional = conditional = config.model.conditional # noise-conditional + fir = config.model.fir + fir_kernel = config.model.fir_kernel + self.skip_rescale = skip_rescale = config.model.skip_rescale + self.resblock_type = resblock_type = config.model.resblock_type.lower() + self.progressive = progressive = config.model.progressive.lower() 
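+        # `progressive` selects the output-side pyramid used in the upsampling
+        # path, while `progressive_input` (set next) selects the input-side
+        # pyramid used while downsampling; valid options are asserted below.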
+ self.progressive_input = progressive_input = config.model.progressive_input.lower() + self.embedding_type = embedding_type = config.model.embedding_type.lower() + init_scale = config.model.init_scale + assert progressive in ["none", "output_skip", "residual"] + assert progressive_input in ["none", "input_skip", "residual"] + assert embedding_type in ["fourier", "positional"] + combine_method = config.model.progressive_combine.lower() + combiner = functools.partial(Combine, method=combine_method) + + modules = [] + # timestep/noise_level embedding; only for continuous training + if embedding_type == "fourier": + # Gaussian Fourier features embeddings. + assert config.training.continuous, "Fourier features are only used for continuous training." + + modules.append(GaussianFourierProjection(embedding_size=nf, scale=config.model.fourier_scale)) + embed_dim = 2 * nf + + elif embedding_type == "positional": + embed_dim = nf + + else: + raise ValueError(f"embedding type {embedding_type} unknown.") + + if conditional: + modules.append(nn.Linear(embed_dim, nf * 4)) + modules[-1].weight.data = default_init()(modules[-1].weight.shape) + nn.init.zeros_(modules[-1].bias) + modules.append(nn.Linear(nf * 4, nf * 4)) + modules[-1].weight.data = default_init()(modules[-1].weight.shape) + nn.init.zeros_(modules[-1].bias) + + AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale) + + Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) + + if progressive == "output_skip": + self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False) + elif progressive == "residual": + pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True) + + Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) + + if progressive_input == "input_skip": + self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False) + elif progressive_input == "residual": + pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True) + + if resblock_type == "ddpm": + ResnetBlock = functools.partial( + ResnetBlockDDPMpp, + act=act, + dropout=dropout, + init_scale=init_scale, + skip_rescale=skip_rescale, + temb_dim=nf * 4, + ) + + elif resblock_type == "biggan": + ResnetBlock = functools.partial( + ResnetBlockBigGANpp, + act=act, + dropout=dropout, + fir=fir, + fir_kernel=fir_kernel, + init_scale=init_scale, + skip_rescale=skip_rescale, + temb_dim=nf * 4, + ) + + else: + raise ValueError(f"resblock type {resblock_type} unrecognized.") + + # Downsampling block + + channels = config.data.num_channels + if progressive_input != "none": + input_pyramid_ch = channels + + modules.append(conv3x3(channels, nf)) + hs_c = [nf] + + in_ch = nf + for i_level in range(num_resolutions): + # Residual blocks for this resolution + for i_block in range(num_res_blocks): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) + in_ch = out_ch + + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + hs_c.append(in_ch) + + if i_level != num_resolutions - 1: + if resblock_type == "ddpm": + modules.append(Downsample(in_ch=in_ch)) + else: + modules.append(ResnetBlock(down=True, in_ch=in_ch)) + + if progressive_input == "input_skip": + modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) + if combine_method == "cat": + in_ch *= 2 + + elif 
progressive_input == "residual": + modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) + input_pyramid_ch = in_ch + + hs_c.append(in_ch) + + in_ch = hs_c[-1] + modules.append(ResnetBlock(in_ch=in_ch)) + modules.append(AttnBlock(channels=in_ch)) + modules.append(ResnetBlock(in_ch=in_ch)) + + pyramid_ch = 0 + # Upsampling block + for i_level in reversed(range(num_resolutions)): + for i_block in range(num_res_blocks + 1): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) + in_ch = out_ch + + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + + if progressive != "none": + if i_level == num_resolutions - 1: + if progressive == "output_skip": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + pyramid_ch = channels + elif progressive == "residual": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, in_ch, bias=True)) + pyramid_ch = in_ch + else: + raise ValueError(f"{progressive} is not a valid name.") + else: + if progressive == "output_skip": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale)) + pyramid_ch = channels + elif progressive == "residual": + modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) + pyramid_ch = in_ch + else: + raise ValueError(f"{progressive} is not a valid name") + + if i_level != 0: + if resblock_type == "ddpm": + modules.append(Upsample(in_ch=in_ch)) + else: + modules.append(ResnetBlock(in_ch=in_ch, up=True)) + + assert not hs_c + + if progressive != "output_skip": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + + self.all_modules = nn.ModuleList(modules) + + def forward(self, x, time_cond): + # import ipdb; ipdb.set_trace() + # timestep/noise_level embedding; only for continuous training + modules = self.all_modules + m_idx = 0 + if self.embedding_type == "fourier": + # Gaussian Fourier features embeddings. + used_sigmas = time_cond + temb = modules[m_idx](torch.log(used_sigmas)) + m_idx += 1 + + elif self.embedding_type == "positional": + # Sinusoidal positional embeddings. 
+ timesteps = time_cond + used_sigmas = self.sigmas[time_cond.long()] + temb = get_timestep_embedding(timesteps, self.nf) + + else: + raise ValueError(f"embedding type {self.embedding_type} unknown.") + + if self.conditional: + temb = modules[m_idx](temb) + m_idx += 1 + temb = modules[m_idx](self.act(temb)) + m_idx += 1 + else: + temb = None + + if not self.config.data.centered: + # If input data is in [0, 1] + x = 2 * x - 1.0 + + # Downsampling block + input_pyramid = None + if self.progressive_input != "none": + input_pyramid = x + + hs = [modules[m_idx](x)] + m_idx += 1 + for i_level in range(self.num_resolutions): + # Residual blocks for this resolution + for i_block in range(self.num_res_blocks): + h = modules[m_idx](hs[-1], temb) + m_idx += 1 + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx += 1 + + hs.append(h) + + if i_level != self.num_resolutions - 1: + if self.resblock_type == "ddpm": + h = modules[m_idx](hs[-1]) + m_idx += 1 + else: + h = modules[m_idx](hs[-1], temb) + m_idx += 1 + + if self.progressive_input == "input_skip": + input_pyramid = self.pyramid_downsample(input_pyramid) + h = modules[m_idx](input_pyramid, h) + m_idx += 1 + + elif self.progressive_input == "residual": + input_pyramid = modules[m_idx](input_pyramid) + m_idx += 1 + if self.skip_rescale: + input_pyramid = (input_pyramid + h) / np.sqrt(2.0) + else: + input_pyramid = input_pyramid + h + h = input_pyramid + + hs.append(h) + + h = hs[-1] + h = modules[m_idx](h, temb) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + h = modules[m_idx](h, temb) + m_idx += 1 + + pyramid = None + + # Upsampling block + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) + m_idx += 1 + + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx += 1 + + if self.progressive != "none": + if i_level == self.num_resolutions - 1: + if self.progressive == "output_skip": + pyramid = self.act(modules[m_idx](h)) + m_idx += 1 + pyramid = modules[m_idx](pyramid) + m_idx += 1 + elif self.progressive == "residual": + pyramid = self.act(modules[m_idx](h)) + m_idx += 1 + pyramid = modules[m_idx](pyramid) + m_idx += 1 + else: + raise ValueError(f"{self.progressive} is not a valid name.") + else: + if self.progressive == "output_skip": + pyramid = self.pyramid_upsample(pyramid) + pyramid_h = self.act(modules[m_idx](h)) + m_idx += 1 + pyramid_h = modules[m_idx](pyramid_h) + m_idx += 1 + pyramid = pyramid + pyramid_h + elif self.progressive == "residual": + pyramid = modules[m_idx](pyramid) + m_idx += 1 + if self.skip_rescale: + pyramid = (pyramid + h) / np.sqrt(2.0) + else: + pyramid = pyramid + h + h = pyramid + else: + raise ValueError(f"{self.progressive} is not a valid name") + + if i_level != 0: + if self.resblock_type == "ddpm": + h = modules[m_idx](h) + m_idx += 1 + else: + h = modules[m_idx](h, temb) + m_idx += 1 + + assert not hs + + if self.progressive == "output_skip": + h = pyramid + else: + h = self.act(modules[m_idx](h)) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + + assert m_idx == len(modules) + if self.config.model.scale_by_sigma: + used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:])))) + h = h / used_sigmas + + return h From 7ca832cac9525c297df0cea43753c0b4e26e491a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Jun 2022 17:20:25 +0000 Subject: [PATCH 10/35] save intermediate state score_sde --- run.py | 288 
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 288 insertions(+) create mode 100755 run.py diff --git a/run.py b/run.py new file mode 100755 index 0000000000..61e29603fb --- /dev/null +++ b/run.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +import numpy as np +import PIL +import functools + +import models +from models import utils as mutils +from models import ncsnv2 +from models import ncsnpp +from models import ddpm as ddpm_model +from models import layerspp +from models import layers +from models import normalization + +from utils import restore_checkpoint + +import sampling +from sde_lib import VESDE, VPSDE, subVPSDE +from sampling import (NoneCorrector, + ReverseDiffusionPredictor, + LangevinCorrector, + EulerMaruyamaPredictor, + AncestralSamplingPredictor, + NonePredictor, + AnnealedLangevinDynamics) +import datasets +import torch + + +torch.manual_seed(0) + + +#class NewVESDE(SDE): +# def __init__(self, sigma_min=0.01, sigma_max=50, N=1000): +# """Construct a Variance Exploding SDE. +# +# Args: +# sigma_min: smallest sigma. +# sigma_max: largest sigma. +# N: number of discretization steps +# """ +# super().__init__(N) +# self.sigma_min = sigma_min +# self.sigma_max = sigma_max +# self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) +# self.N = N +# +# @property +# def T(self): +# return 1 +# +# def sde(self, x, t): +# sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t +# drift = torch.zeros_like(x) +# diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), +# device=t.device)) +# return drift, diffusion +# +# def marginal_prob(self, x, t): +# std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t +# mean = x +# return mean, std +# +# def prior_sampling(self, shape): +# return torch.randn(*shape) * self.sigma_max +# +# def prior_logp(self, z): +# shape = z.shape +# N = np.prod(shape[1:]) +# return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2) +# +# def discretize(self, x, t): +# """SMLD(NCSN) discretization.""" +# timestep = (t * (self.N - 1) / self.T).long() +# sigma = self.discrete_sigmas.to(t.device)[timestep] +# adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), +# self.discrete_sigmas[timestep - 1].to(t.device)) +# f = torch.zeros_like(x) +# G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) +# return f, G + + +class NewReverseDiffusionPredictor: + + def __init__(self, sde, score_fn, probability_flow=False): + super().__init__() + self.sde = sde + self.probability_flow = probability_flow + self.score_fn = score_fn + + def discretize(self, x, t): + timestep = (t * (self.sde.N - 1) / self.sde.T).long() + sigma = self.sde.discrete_sigmas.to(t.device)[timestep] + adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), + self.sde.discrete_sigmas[timestep - 1].to(t.device)) + f = torch.zeros_like(x) + G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) + + labels = self.sde.marginal_prob(torch.zeros_like(x), t)[1] + result = self.score_fn(x, labels) + + rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.) 
+ rev_G = torch.zeros_like(G) if self.probability_flow else G + return rev_f, rev_G + + def update_fn(self, x, t): + f, G = self.discretize(x, t) + z = torch.randn_like(x) + x_mean = x - f + x = x_mean + G[:, None, None, None] * z + return x, x_mean + + +class NewLangevinCorrector: + + def __init__(self, sde, score_fn, snr, n_steps): + super().__init__() + self.sde = sde + self.score_fn = score_fn + self.snr = snr + self.n_steps = n_steps + + def update_fn(self, x, t): + sde = self.sde + score_fn = self.score_fn + n_steps = self.n_steps + target_snr = self.snr + if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE): + timestep = (t * (sde.N - 1) / sde.T).long() + alpha = sde.alphas.to(t.device)[timestep] + else: + alpha = torch.ones_like(t) + + for i in range(n_steps): + labels = sde.marginal_prob(torch.zeros_like(x), t)[1] + grad = score_fn(x, labels) + noise = torch.randn_like(x) + grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha + x_mean = x + step_size[:, None, None, None] * grad + x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise + + return x, x_mean + + + +def save_image(x): +# image_processed = x.cpu().permute(0, 2, 3, 1) +# image_processed = (image_processed + 1.0) * 127.5 +# image_processed = image_processed.numpy().astype(np.uint8) + image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) + image_pil = PIL.Image.fromarray(image_processed[0]) + + # 6. save image + image_pil.save("../images/hey.png") + + +#x = np.load("cifar10.npy") +# +#save_image(x) +# @title Load the score-based model +sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"} +if sde.lower() == 'vesde': + from configs.ve import cifar10_ncsnpp_continuous as configs + ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" +# from configs.ve import ffhq_ncsnpp_continuous as configs +# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" + config = configs.get_config() + config.model.num_scales = 1000 + sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) + sampling_eps = 1e-5 +elif sde.lower() == 'vpsde': + from configs.vp import cifar10_ddpmpp_continuous as configs + ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth" + config = configs.get_config() + sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) + sampling_eps = 1e-3 +elif sde.lower() == 'subvpsde': + from configs.subvp import cifar10_ddpmpp_continuous as configs + ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth" + config = configs.get_config() + sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) + sampling_eps = 1e-3 + +batch_size = 1 #@param {"type":"integer"} +config.training.batch_size = batch_size +config.eval.batch_size = batch_size + +random_seed = 0 #@param {"type": "integer"} + +score_model = mutils.create_model(config) + +loaded_state = torch.load(ckpt_filename) +score_model.load_state_dict(loaded_state["model"], strict=False) + +inverse_scaler = datasets.get_data_inverse_scaler(config) +predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"} +corrector = LangevinCorrector #@param 
["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"} + +def image_grid(x): + size = config.data.image_size + channels = config.data.num_channels + img = x.reshape(-1, size, size, channels) + w = int(np.sqrt(img.shape[0])) + img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels)) + return img + +#@title PC sampling +img_size = config.data.image_size +channels = config.data.num_channels +shape = (batch_size, channels, img_size, img_size) +probability_flow = False +snr = 0.16 #@param {"type": "number"} +n_steps = 1#@param {"type": "integer"} + + +def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous): + """A wrapper that configures and returns the update function of predictors.""" + score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) + if predictor is None: + # Corrector-only sampler + predictor_obj = NonePredictor(sde, score_fn, probability_flow) + else: + predictor_obj = predictor(sde, score_fn, probability_flow) + return predictor_obj.update_fn(x, t) + + +def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps): + """A wrapper tha configures and returns the update function of correctors.""" + score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) + if corrector is None: + # Predictor-only sampler + corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps) + else: + corrector_obj = corrector(sde, score_fn, snr, n_steps) + return corrector_obj.update_fn(x, t) + + +continuous = config.training.continuous + + +predictor_update_fn = functools.partial(shared_predictor_update_fn, + sde=sde, + predictor=predictor, + probability_flow=probability_flow, + continuous=continuous) + +corrector_update_fn = functools.partial(shared_corrector_update_fn, + sde=sde, + corrector=corrector, + continuous=continuous, + snr=snr, + n_steps=n_steps) + +device = "cuda" +model = score_model.to(device) +denoise = False + +new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps) +new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model) + + +with torch.no_grad(): + # Initial sample + x = sde.prior_sampling(shape).to(device) + timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=device) + + for i in range(sde.N): + t = timesteps[i] + vec_t = torch.ones(shape[0], device=t.device) * t +# x, x_mean = corrector_update_fn(x, vec_t, model=model) +# x, x_mean = predictor_update_fn(x, vec_t, model=model) + x, x_mean = new_corrector.update_fn(x, vec_t) + x, x_mean = new_predictor.update_fn(x, vec_t) + + x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) + + +# for 5 +#assert x.abs().sum().cpu().item() - 106114.90625 < 1e-2, "sum wrong" +#assert x.abs().mean().cpu().item() - 34.5426139831543 < 1e-4, "mean wrong" + +# for 1000 +assert x.abs().sum().cpu().item() - 436.5811 < 1e-2, "sum wrong" +assert x.abs().mean().cpu().item() - 0.1421 < 1e-4, "mean wrong" + +save_image(x) From fc67917a181a4cbd539c794470948ffeb89e5b1d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Jun 2022 17:35:19 +0000 Subject: [PATCH 11/35] up --- run.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/run.py b/run.py index 61e29603fb..b2ec6eea29 100755 --- a/run.py +++ b/run.py @@ -269,20 +269,21 @@ with torch.no_grad(): for i in range(sde.N): t = timesteps[i] vec_t = torch.ones(shape[0], device=t.device) * t -# x, x_mean = 
corrector_update_fn(x, vec_t, model=model) -# x, x_mean = predictor_update_fn(x, vec_t, model=model) - x, x_mean = new_corrector.update_fn(x, vec_t) - x, x_mean = new_predictor.update_fn(x, vec_t) + x, x_mean = corrector_update_fn(x, vec_t, model=model) + x, x_mean = predictor_update_fn(x, vec_t, model=model) +# x, x_mean = new_corrector.update_fn(x, vec_t) +# x, x_mean = new_predictor.update_fn(x, vec_t) x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) +save_image(x) + # for 5 -#assert x.abs().sum().cpu().item() - 106114.90625 < 1e-2, "sum wrong" -#assert x.abs().mean().cpu().item() - 34.5426139831543 < 1e-4, "mean wrong" +#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" +#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" # for 1000 -assert x.abs().sum().cpu().item() - 436.5811 < 1e-2, "sum wrong" -assert x.abs().mean().cpu().item() - 0.1421 < 1e-4, "mean wrong" +assert (x.abs().sum() - 436.5811).abs().sum().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" +assert (x.abs().mean() - 0.1421).abs().mean().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" -save_image(x) From 78e99a997bb29bbaa7b91fa0ff233e46bee95e9c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Jun 2022 18:48:26 +0000 Subject: [PATCH 12/35] adapt run.py --- run.py | 110 +++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 67 insertions(+), 43 deletions(-) diff --git a/run.py b/run.py index b2ec6eea29..7a55acbab2 100755 --- a/run.py +++ b/run.py @@ -11,6 +11,8 @@ from models import ddpm as ddpm_model from models import layerspp from models import layers from models import normalization +from models.ema import ExponentialMovingAverage +from losses import get_optimizer from utils import restore_checkpoint @@ -27,6 +29,7 @@ import datasets import torch +torch.backends.cuda.matmul.allow_tf32 = False torch.manual_seed(0) @@ -81,7 +84,6 @@ torch.manual_seed(0) class NewReverseDiffusionPredictor: - def __init__(self, sde, score_fn, probability_flow=False): super().__init__() self.sde = sde @@ -112,7 +114,6 @@ class NewReverseDiffusionPredictor: class NewLangevinCorrector: - def __init__(self, sde, score_fn, snr, n_steps): super().__init__() self.sde = sde @@ -146,28 +147,19 @@ class NewLangevinCorrector: def save_image(x): -# image_processed = x.cpu().permute(0, 2, 3, 1) -# image_processed = (image_processed + 1.0) * 127.5 -# image_processed = image_processed.numpy().astype(np.uint8) image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) image_pil = PIL.Image.fromarray(image_processed[0]) - - # 6. 
save image image_pil.save("../images/hey.png") -#x = np.load("cifar10.npy") -# -#save_image(x) -# @title Load the score-based model sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"} if sde.lower() == 'vesde': - from configs.ve import cifar10_ncsnpp_continuous as configs - ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" -# from configs.ve import ffhq_ncsnpp_continuous as configs -# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" +# from configs.ve import cifar10_ncsnpp_continuous as configs +# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" + from configs.ve import ffhq_ncsnpp_continuous as configs + ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" config = configs.get_config() - config.model.num_scales = 1000 + config.model.num_scales = 2 sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) sampling_eps = 1e-5 elif sde.lower() == 'vpsde': @@ -189,32 +181,53 @@ config.eval.batch_size = batch_size random_seed = 0 #@param {"type": "integer"} -score_model = mutils.create_model(config) +#sigmas = mutils.get_sigmas(config) +#scaler = datasets.get_data_scaler(config) +#inverse_scaler = datasets.get_data_inverse_scaler(config) +#score_model = mutils.create_model(config) +# +#optimizer = get_optimizer(config, score_model.parameters()) +#ema = ExponentialMovingAverage(score_model.parameters(), +# decay=config.model.ema_rate) +#state = dict(step=0, optimizer=optimizer, +# model=score_model, ema=ema) +# +#state = restore_checkpoint(ckpt_filename, state, config.device) +#ema.copy_to(score_model.parameters()) -loaded_state = torch.load(ckpt_filename) -score_model.load_state_dict(loaded_state["model"], strict=False) +#score_model = mutils.create_model(config) + +from diffusers import NCSNpp +score_model = NCSNpp(config).to(config.device) +score_model = torch.nn.DataParallel(score_model) + +loaded_state = torch.load("./ffhq_1024_ncsnpp_continuous_ema.pt") +del loaded_state["module.sigmas"] +score_model.load_state_dict(loaded_state, strict=False) inverse_scaler = datasets.get_data_inverse_scaler(config) predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"} corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"} -def image_grid(x): - size = config.data.image_size - channels = config.data.num_channels - img = x.reshape(-1, size, size, channels) - w = int(np.sqrt(img.shape[0])) - img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels)) - return img - #@title PC sampling img_size = config.data.image_size channels = config.data.num_channels shape = (batch_size, channels, img_size, img_size) probability_flow = False -snr = 0.16 #@param {"type": "number"} +snr = 0.15 #@param {"type": "number"} n_steps = 1#@param {"type": "integer"} +#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector, +# inverse_scaler, snr, n_steps=n_steps, +# probability_flow=probability_flow, +# continuous=config.training.continuous, +# eps=sampling_eps, device=config.device) +# +#x, n = sampling_fn(score_model) +#save_image(x) + + def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous): """A wrapper that configures and returns the update function of predictors.""" score_fn = mutils.get_score_fn(sde, model, 
train=False, continuous=continuous) @@ -253,14 +266,14 @@ corrector_update_fn = functools.partial(shared_corrector_update_fn, snr=snr, n_steps=n_steps) -device = "cuda" -model = score_model.to(device) -denoise = False +device = config.device +model = score_model +denoise = True new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps) new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model) - +# with torch.no_grad(): # Initial sample x = sde.prior_sampling(shape).to(device) @@ -269,21 +282,32 @@ with torch.no_grad(): for i in range(sde.N): t = timesteps[i] vec_t = torch.ones(shape[0], device=t.device) * t - x, x_mean = corrector_update_fn(x, vec_t, model=model) - x, x_mean = predictor_update_fn(x, vec_t, model=model) -# x, x_mean = new_corrector.update_fn(x, vec_t) -# x, x_mean = new_predictor.update_fn(x, vec_t) +# x, x_mean = corrector_update_fn(x, vec_t, model=model) +# x, x_mean = predictor_update_fn(x, vec_t, model=model) + x, x_mean = new_corrector.update_fn(x, vec_t) + x, x_mean = new_predictor.update_fn(x, vec_t) x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) -save_image(x) -# for 5 -#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" -#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" +#save_image(x) -# for 1000 -assert (x.abs().sum() - 436.5811).abs().sum().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" -assert (x.abs().mean() - 0.1421).abs().mean().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" +# for 5 cifar10 +x_sum = 106071.9922 +x_mean = 34.52864456176758 +# for 1000 cifar10 +x_sum = 461.9700 +x_mean = 0.1504 + +# for 2 for 1024 +x_sum = 3382810112.0 +x_mean = 1075.366455078125 + +def check_x_sum_x_mean(x, x_sum, x_mean): + assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" + assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" + + +check_x_sum_x_mean(x, x_sum, x_mean) From 49a81f9f1ac006a52496899a16c7a79993247a98 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Jun 2022 19:44:17 +0000 Subject: [PATCH 13/35] port first 1024 model --- run.py | 350 +++++++++++++++++++++++++-------------------------------- 1 file changed, 150 insertions(+), 200 deletions(-) diff --git a/run.py b/run.py index 7a55acbab2..0180c3489b 100755 --- a/run.py +++ b/run.py @@ -1,104 +1,128 @@ #!/usr/bin/env python3 import numpy as np import PIL -import functools - -import models -from models import utils as mutils -from models import ncsnv2 -from models import ncsnpp -from models import ddpm as ddpm_model -from models import layerspp -from models import layers -from models import normalization -from models.ema import ExponentialMovingAverage -from losses import get_optimizer - -from utils import restore_checkpoint - -import sampling -from sde_lib import VESDE, VPSDE, subVPSDE -from sampling import (NoneCorrector, - ReverseDiffusionPredictor, - LangevinCorrector, - EulerMaruyamaPredictor, - AncestralSamplingPredictor, - NonePredictor, - AnnealedLangevinDynamics) -import datasets import torch +import ml_collections +#from configs.ve import ffhq_ncsnpp_continuous as configs +# from configs.ve import cifar10_ncsnpp_continuous as configs + + +# ffhq_ncsnpp_continuous config +def get_config(): + config = ml_collections.ConfigDict() + # training + config.training = training = ml_collections.ConfigDict() + training.batch_size = 8 + training.n_iters = 2400001 + 
training.snapshot_freq = 50000 + training.log_freq = 50 + training.eval_freq = 100 + training.snapshot_freq_for_preemption = 5000 + training.snapshot_sampling = True + training.sde = 'vesde' + training.continuous = True + training.likelihood_weighting = False + training.reduce_mean = True + + # sampling + config.sampling = sampling = ml_collections.ConfigDict() + sampling.method = 'pc' + sampling.predictor = 'reverse_diffusion' + sampling.corrector = 'langevin' + sampling.probability_flow = False + sampling.snr = 0.15 + sampling.n_steps_each = 1 + sampling.noise_removal = True + + # eval + config.eval = evaluate = ml_collections.ConfigDict() + evaluate.batch_size = 1024 + evaluate.num_samples = 50000 + evaluate.begin_ckpt = 1 + evaluate.end_ckpt = 96 + + # data + config.data = data = ml_collections.ConfigDict() + data.dataset = 'FFHQ' + data.image_size = 1024 + data.centered = False + data.random_flip = True + data.uniform_dequantization = False + data.num_channels = 3 + # Plug in your own path to the tfrecords file. + data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords' + + # model + config.model = model = ml_collections.ConfigDict() + model.name = 'ncsnpp' + model.scale_by_sigma = True + model.sigma_max = 1348 + model.num_scales = 2000 + model.ema_rate = 0.9999 + model.sigma_min = 0.01 + model.normalization = 'GroupNorm' + model.nonlinearity = 'swish' + model.nf = 16 + model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32) + model.num_res_blocks = 1 + model.attn_resolutions = (16,) + model.dropout = 0. + model.resamp_with_conv = True + model.conditional = True + model.fir = True + model.fir_kernel = [1, 3, 3, 1] + model.skip_rescale = True + model.resblock_type = 'biggan' + model.progressive = 'output_skip' + model.progressive_input = 'input_skip' + model.progressive_combine = 'sum' + model.attention_type = 'ddpm' + model.init_scale = 0. + model.fourier_scale = 16 + model.conv_size = 3 + model.embedding_type = 'fourier' + + # optim + config.optim = optim = ml_collections.ConfigDict() + optim.weight_decay = 0 + optim.optimizer = 'Adam' + optim.lr = 2e-4 + optim.beta1 = 0.9 + optim.amsgrad = False + optim.eps = 1e-8 + optim.warmup = 5000 + optim.grad_clip = 1. + + config.seed = 42 + config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + + return config torch.backends.cuda.matmul.allow_tf32 = False -torch.manual_seed(0) - - -#class NewVESDE(SDE): -# def __init__(self, sigma_min=0.01, sigma_max=50, N=1000): -# """Construct a Variance Exploding SDE. -# -# Args: -# sigma_min: smallest sigma. -# sigma_max: largest sigma. -# N: number of discretization steps -# """ -# super().__init__(N) -# self.sigma_min = sigma_min -# self.sigma_max = sigma_max -# self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) -# self.N = N -# -# @property -# def T(self): -# return 1 -# -# def sde(self, x, t): -# sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t -# drift = torch.zeros_like(x) -# diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), -# device=t.device)) -# return drift, diffusion -# -# def marginal_prob(self, x, t): -# std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t -# mean = x -# return mean, std -# -# def prior_sampling(self, shape): -# return torch.randn(*shape) * self.sigma_max -# -# def prior_logp(self, z): -# shape = z.shape -# N = np.prod(shape[1:]) -# return -N / 2. 
* np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2) -# -# def discretize(self, x, t): -# """SMLD(NCSN) discretization.""" -# timestep = (t * (self.N - 1) / self.T).long() -# sigma = self.discrete_sigmas.to(t.device)[timestep] -# adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), -# self.discrete_sigmas[timestep - 1].to(t.device)) -# f = torch.zeros_like(x) -# G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) -# return f, G +torch.manual_seed(3) class NewReverseDiffusionPredictor: - def __init__(self, sde, score_fn, probability_flow=False): + def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0): super().__init__() - self.sde = sde + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.N = N + self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) + self.probability_flow = probability_flow self.score_fn = score_fn def discretize(self, x, t): - timestep = (t * (self.sde.N - 1) / self.sde.T).long() - sigma = self.sde.discrete_sigmas.to(t.device)[timestep] + timestep = (t * (self.N - 1)).long() + sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), - self.sde.discrete_sigmas[timestep - 1].to(t.device)) + self.discrete_sigmas[timestep - 1].to(t.device)) f = torch.zeros_like(x) G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) - labels = self.sde.marginal_prob(torch.zeros_like(x), t)[1] + labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t result = self.score_fn(x, labels) rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.) @@ -114,26 +138,27 @@ class NewReverseDiffusionPredictor: class NewLangevinCorrector: - def __init__(self, sde, score_fn, snr, n_steps): + def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0): super().__init__() - self.sde = sde self.score_fn = score_fn self.snr = snr self.n_steps = n_steps + self.sigma_min = sigma_min + self.sigma_max = sigma_max + def update_fn(self, x, t): - sde = self.sde score_fn = self.score_fn n_steps = self.n_steps target_snr = self.snr - if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE): - timestep = (t * (sde.N - 1) / sde.T).long() - alpha = sde.alphas.to(t.device)[timestep] - else: - alpha = torch.ones_like(t) +# if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE): +# timestep = (t * (sde.N - 1) / sde.T).long() +# alpha = sde.alphas.to(t.device)[timestep] +# else: + alpha = torch.ones_like(t) for i in range(n_steps): - labels = sde.marginal_prob(torch.zeros_like(x), t)[1] + labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t grad = score_fn(x, labels) noise = torch.randn_like(x) grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() @@ -152,64 +177,42 @@ def save_image(x): image_pil.save("../images/hey.png") -sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"} -if sde.lower() == 'vesde': -# from configs.ve import cifar10_ncsnpp_continuous as configs # ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" - from configs.ve import ffhq_ncsnpp_continuous as configs - ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" - config = configs.get_config() - config.model.num_scales = 2 - sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) - sampling_eps = 1e-5 -elif sde.lower() == 'vpsde': - from configs.vp import 
cifar10_ddpmpp_continuous as configs - ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth" - config = configs.get_config() - sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) - sampling_eps = 1e-3 -elif sde.lower() == 'subvpsde': - from configs.subvp import cifar10_ddpmpp_continuous as configs - ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth" - config = configs.get_config() - sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) - sampling_eps = 1e-3 +#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" +# Note usually we need to restore ema etc... +# ema restored checkpoint used from below + + + +config = get_config() + +sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max +N = config.model.num_scales + +sampling_eps = 1e-5 batch_size = 1 #@param {"type":"integer"} config.training.batch_size = batch_size config.eval.batch_size = batch_size -random_seed = 0 #@param {"type": "integer"} - -#sigmas = mutils.get_sigmas(config) -#scaler = datasets.get_data_scaler(config) -#inverse_scaler = datasets.get_data_inverse_scaler(config) -#score_model = mutils.create_model(config) -# -#optimizer = get_optimizer(config, score_model.parameters()) -#ema = ExponentialMovingAverage(score_model.parameters(), -# decay=config.model.ema_rate) -#state = dict(step=0, optimizer=optimizer, -# model=score_model, ema=ema) -# -#state = restore_checkpoint(ckpt_filename, state, config.device) -#ema.copy_to(score_model.parameters()) - -#score_model = mutils.create_model(config) - from diffusers import NCSNpp -score_model = NCSNpp(config).to(config.device) -score_model = torch.nn.DataParallel(score_model) +model = NCSNpp(config).to(config.device) +model = torch.nn.DataParallel(model) -loaded_state = torch.load("./ffhq_1024_ncsnpp_continuous_ema.pt") +loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt") del loaded_state["module.sigmas"] -score_model.load_state_dict(loaded_state, strict=False) +model.load_state_dict(loaded_state, strict=False) -inverse_scaler = datasets.get_data_inverse_scaler(config) -predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"} -corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"} +def get_data_inverse_scaler(config): + """Inverse data normalizer.""" + if config.data.centered: + # Rescale [-1, 1] to [0, 1] + return lambda x: (x + 1.) / 2. 
+ else: + return lambda x: x + +inverse_scaler = get_data_inverse_scaler(config)  img_size = config.data.image_size channels = config.data.num_channels shape = (batch_size, channels, img_size, img_size) @@ -218,80 +221,27 @@ snr = 0.15 #@param {"type": "number"} n_steps = 1#@param {"type": "integer"} -#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector, -# inverse_scaler, snr, n_steps=n_steps, -# probability_flow=probability_flow, -# continuous=config.training.continuous, -# eps=sampling_eps, device=config.device) -# -#x, n = sampling_fn(score_model) -#save_image(x) - - -def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous): - """A wrapper that configures and returns the update function of predictors.""" - score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) - if predictor is None: - # Corrector-only sampler - predictor_obj = NonePredictor(sde, score_fn, probability_flow) - else: - predictor_obj = predictor(sde, score_fn, probability_flow) - return predictor_obj.update_fn(x, t) - - -def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps): - """A wrapper that configures and returns the update function of correctors.""" - score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) - if corrector is None: - # Predictor-only sampler - corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps) - else: - corrector_obj = corrector(sde, score_fn, snr, n_steps) - return corrector_obj.update_fn(x, t) - - -continuous = config.training.continuous - - -predictor_update_fn = functools.partial(shared_predictor_update_fn, - sde=sde, - predictor=predictor, - probability_flow=probability_flow, - continuous=continuous) - -corrector_update_fn = functools.partial(shared_corrector_update_fn, - sde=sde, - corrector=corrector, - continuous=continuous, - snr=snr, - n_steps=n_steps) - device = config.device -model = score_model -denoise = True +new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max) +new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N) -new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps) -new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model) -# with torch.no_grad(): # Initial sample - x = sde.prior_sampling(shape).to(device) - timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=device) + x = torch.randn(*shape) * sigma_max + x = x.to(device) + timesteps = torch.linspace(1, sampling_eps, N, device=device) - for i in range(sde.N): + for i in range(N): t = timesteps[i] vec_t = torch.ones(shape[0], device=t.device) * t x, x_mean = new_corrector.update_fn(x, vec_t) x, x_mean = new_predictor.update_fn(x, vec_t) - x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) + x = inverse_scaler(x_mean) - -#save_image(x) +save_image(x) # for 5 cifar10 x_sum = 106071.9922 @@ -310,4 +260,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean): assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" -check_x_sum_x_mean(x, x_sum, x_mean) +#check_x_sum_x_mean(x, x_sum, x_mean) From bc2d586dcbba9429f4b0d9600a559fff18f599b6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Jun 2022 00:53:55 +0000 Subject: [PATCH 14/35] remove
more dependencies --- run.py | 146 +++--------------- .../models/unet_sde_score_estimation.py | 125 ++++++++++----- 2 files changed, 105 insertions(+), 166 deletions(-) diff --git a/run.py b/run.py index 0180c3489b..cae9713967 100755 --- a/run.py +++ b/run.py @@ -2,105 +2,14 @@ import numpy as np import PIL import torch -import ml_collections #from configs.ve import ffhq_ncsnpp_continuous as configs # from configs.ve import cifar10_ncsnpp_continuous as configs -# ffhq_ncsnpp_continuous config -def get_config(): - config = ml_collections.ConfigDict() - # training - config.training = training = ml_collections.ConfigDict() - training.batch_size = 8 - training.n_iters = 2400001 - training.snapshot_freq = 50000 - training.log_freq = 50 - training.eval_freq = 100 - training.snapshot_freq_for_preemption = 5000 - training.snapshot_sampling = True - training.sde = 'vesde' - training.continuous = True - training.likelihood_weighting = False - training.reduce_mean = True - - # sampling - config.sampling = sampling = ml_collections.ConfigDict() - sampling.method = 'pc' - sampling.predictor = 'reverse_diffusion' - sampling.corrector = 'langevin' - sampling.probability_flow = False - sampling.snr = 0.15 - sampling.n_steps_each = 1 - sampling.noise_removal = True - - # eval - config.eval = evaluate = ml_collections.ConfigDict() - evaluate.batch_size = 1024 - evaluate.num_samples = 50000 - evaluate.begin_ckpt = 1 - evaluate.end_ckpt = 96 - - # data - config.data = data = ml_collections.ConfigDict() - data.dataset = 'FFHQ' - data.image_size = 1024 - data.centered = False - data.random_flip = True - data.uniform_dequantization = False - data.num_channels = 3 - # Plug in your own path to the tfrecords file. - data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords' - - # model - config.model = model = ml_collections.ConfigDict() - model.name = 'ncsnpp' - model.scale_by_sigma = True - model.sigma_max = 1348 - model.num_scales = 2000 - model.ema_rate = 0.9999 - model.sigma_min = 0.01 - model.normalization = 'GroupNorm' - model.nonlinearity = 'swish' - model.nf = 16 - model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32) - model.num_res_blocks = 1 - model.attn_resolutions = (16,) - model.dropout = 0. - model.resamp_with_conv = True - model.conditional = True - model.fir = True - model.fir_kernel = [1, 3, 3, 1] - model.skip_rescale = True - model.resblock_type = 'biggan' - model.progressive = 'output_skip' - model.progressive_input = 'input_skip' - model.progressive_combine = 'sum' - model.attention_type = 'ddpm' - model.init_scale = 0. - model.fourier_scale = 16 - model.conv_size = 3 - model.embedding_type = 'fourier' - - # optim - config.optim = optim = ml_collections.ConfigDict() - optim.weight_decay = 0 - optim.optimizer = 'Adam' - optim.lr = 2e-4 - optim.beta1 = 0.9 - optim.amsgrad = False - optim.eps = 1e-8 - optim.warmup = 5000 - optim.grad_clip = 1. - - config.seed = 42 - config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') - - return config - +device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') torch.backends.cuda.matmul.allow_tf32 = False -torch.manual_seed(3) +torch.manual_seed(0) class NewReverseDiffusionPredictor: @@ -182,47 +91,26 @@ def save_image(x): # Note usually we need to restore ema etc... 
# ema restored checkpoint used from below - - -config = get_config() - -sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max -N = config.model.num_scales - +N = 2 +sigma_min = 0.01 +sigma_max = 1348 sampling_eps = 1e-5 - -batch_size = 1 #@param {"type":"integer"} -config.training.batch_size = batch_size -config.eval.batch_size = batch_size +batch_size = 1 +centered = False from diffusers import NCSNpp -model = NCSNpp(config).to(config.device) + +model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device) model = torch.nn.DataParallel(model) -loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt") -del loaded_state["module.sigmas"] -model.load_state_dict(loaded_state, strict=False) - -def get_data_inverse_scaler(config): - """Inverse data normalizer.""" - if config.data.centered: - # Rescale [-1, 1] to [0, 1] - return lambda x: (x + 1.) / 2. - else: - return lambda x: x - -inverse_scaler = get_data_inverse_scaler(config) - -img_size = config.data.image_size -channels = config.data.num_channels +img_size = model.module.config.image_size +channels = model.module.config.num_channels shape = (batch_size, channels, img_size, img_size) probability_flow = False -snr = 0.15 #@param {"type": "number"} -n_steps = 1#@param {"type": "integer"} +snr = 0.15 +n_steps = 1 -device = config.device - new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max) new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N) @@ -238,10 +126,12 @@ with torch.no_grad(): x, x_mean = new_corrector.update_fn(x, vec_t) x, x_mean = new_predictor.update_fn(x, vec_t) - x = inverse_scaler(x_mean) + x = x_mean + if centered: + x = (x + 1.) / 2. 
-save_image(x) +# save_image(x) # for 5 cifar10 x_sum = 106071.9922 @@ -260,4 +150,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean): assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" -#check_x_sum_x_mean(x, x_sum, x_mean) +check_x_sum_x_mean(x, x_sum, x_mean) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 26b4419ea2..30671ef293 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -15,6 +15,9 @@ # helpers functions +from ..modeling_utils import ModelMixin +from ..configuration_utils import ConfigMixin + import functools import math @@ -372,16 +375,16 @@ class NIN(nn.Module): return y.permute(0, 3, 1, 2) -def get_act(config): +def get_act(nonlinearity): """Get activation functions from the config file.""" - if config.model.nonlinearity.lower() == "elu": + if nonlinearity.lower() == "elu": return nn.ELU() - elif config.model.nonlinearity.lower() == "relu": + elif nonlinearity.lower() == "relu": return nn.ReLU() - elif config.model.nonlinearity.lower() == "lrelu": + elif nonlinearity.lower() == "lrelu": return nn.LeakyReLU(negative_slope=0.2) - elif config.model.nonlinearity.lower() == "swish": + elif nonlinearity.lower() == "swish": return nn.SiLU() else: raise NotImplementedError("activation function does not exist!") @@ -710,46 +713,93 @@ class ResnetBlockBigGANpp(nn.Module): return (x + h) / np.sqrt(2.0) -class NCSNpp(nn.Module): +class NCSNpp(ModelMixin, ConfigMixin): """NCSN++ model""" - def __init__(self, config): + def __init__( + self, + centered=False, + image_size=1024, + num_channels=3, + attention_type="ddpm", + attn_resolutions=(16,), + ch_mult=(1, 2, 4, 8, 16, 32, 32, 32), + conditional=True, + conv_size=3, + dropout=0.0, + embedding_type="fourier", + fir=True, + fir_kernel=(1, 3, 3, 1), + fourier_scale=16, + init_scale=0.0, + nf=16, + nonlinearity="swish", + normalization="GroupNorm", + num_res_blocks=1, + progressive="output_skip", + progressive_combine="sum", + progressive_input="input_skip", + resamp_with_conv=True, + resblock_type="biggan", + scale_by_sigma=True, + skip_rescale=True, + continuous=True, + ): super().__init__() - self.config = config - self.act = act = get_act(config) + self.register_to_config( + centered=centered, + image_size=image_size, + num_channels=num_channels, + attention_type=attention_type, + attn_resolutions=attn_resolutions, + ch_mult=ch_mult, + conditional=conditional, + conv_size=conv_size, + dropout=dropout, + embedding_type=embedding_type, + fir=fir, + fir_kernel=fir_kernel, + fourier_scale=fourier_scale, + init_scale=init_scale, + nf=nf, + nonlinearity=nonlinearity, + normalization=normalization, + num_res_blocks=num_res_blocks, + progressive=progressive, + progressive_combine=progressive_combine, + progressive_input=progressive_input, + resamp_with_conv=resamp_with_conv, + resblock_type=resblock_type, + scale_by_sigma=scale_by_sigma, + skip_rescale=skip_rescale, + continuous=continuous, + ) + self.act = act = get_act(nonlinearity) # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) - self.nf = nf = config.model.nf - ch_mult = config.model.ch_mult - self.num_res_blocks = num_res_blocks = config.model.num_res_blocks - self.attn_resolutions = attn_resolutions = config.model.attn_resolutions - dropout = config.model.dropout - resamp_with_conv = config.model.resamp_with_conv - self.num_resolutions = num_resolutions = len(ch_mult) - 
self.all_resolutions = all_resolutions = [config.data.image_size // (2**i) for i in range(num_resolutions)] + self.nf = nf + self.num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions + self.num_resolutions = len(ch_mult) + self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)] - self.conditional = conditional = config.model.conditional # noise-conditional - fir = config.model.fir - fir_kernel = config.model.fir_kernel - self.skip_rescale = skip_rescale = config.model.skip_rescale - self.resblock_type = resblock_type = config.model.resblock_type.lower() - self.progressive = progressive = config.model.progressive.lower() - self.progressive_input = progressive_input = config.model.progressive_input.lower() - self.embedding_type = embedding_type = config.model.embedding_type.lower() - init_scale = config.model.init_scale + self.conditional = conditional + self.skip_rescale = skip_rescale + self.resblock_type = resblock_type + self.progressive = progressive + self.progressive_input = progressive_input + self.embedding_type = embedding_type assert progressive in ["none", "output_skip", "residual"] assert progressive_input in ["none", "input_skip", "residual"] assert embedding_type in ["fourier", "positional"] - combine_method = config.model.progressive_combine.lower() + combine_method = progressive_combine.lower() combiner = functools.partial(Combine, method=combine_method) modules = [] # timestep/noise_level embedding; only for continuous training if embedding_type == "fourier": # Gaussian Fourier features embeddings. - assert config.training.continuous, "Fourier features are only used for continuous training." - - modules.append(GaussianFourierProjection(embedding_size=nf, scale=config.model.fourier_scale)) + modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale)) embed_dim = 2 * nf elif embedding_type == "positional": @@ -809,7 +859,7 @@ class NCSNpp(nn.Module): # Downsampling block - channels = config.data.num_channels + channels = num_channels if progressive_input != "none": input_pyramid_ch = channels @@ -817,7 +867,7 @@ class NCSNpp(nn.Module): hs_c = [nf] in_ch = nf - for i_level in range(num_resolutions): + for i_level in range(self.num_resolutions): # Residual blocks for this resolution for i_block in range(num_res_blocks): out_ch = nf * ch_mult[i_level] @@ -828,7 +878,7 @@ class NCSNpp(nn.Module): modules.append(AttnBlock(channels=in_ch)) hs_c.append(in_ch) - if i_level != num_resolutions - 1: + if i_level != self.num_resolutions - 1: if resblock_type == "ddpm": modules.append(Downsample(in_ch=in_ch)) else: @@ -852,7 +902,7 @@ class NCSNpp(nn.Module): pyramid_ch = 0 # Upsampling block - for i_level in reversed(range(num_resolutions)): + for i_level in reversed(range(self.num_resolutions)): for i_block in range(num_res_blocks + 1): out_ch = nf * ch_mult[i_level] modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) @@ -862,7 +912,7 @@ class NCSNpp(nn.Module): modules.append(AttnBlock(channels=in_ch)) if progressive != "none": - if i_level == num_resolutions - 1: + if i_level == self.num_resolutions - 1: if progressive == "output_skip": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) @@ -899,7 +949,6 @@ class NCSNpp(nn.Module): self.all_modules = nn.ModuleList(modules) def forward(self, x, time_cond): - # import ipdb; ipdb.set_trace() # timestep/noise_level embedding; only 
for continuous training modules = self.all_modules m_idx = 0 @@ -926,7 +975,7 @@ class NCSNpp(nn.Module): else: temb = None - if not self.config.data.centered: + if not self.config.centered: # If input data is in [0, 1] x = 2 * x - 1.0 @@ -1044,7 +1093,7 @@ class NCSNpp(nn.Module): m_idx += 1 assert m_idx == len(modules) - if self.config.model.scale_by_sigma: + if self.config.scale_by_sigma: used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:])))) h = h / used_sigmas From de810814dad81b6be5534cd071455c5e14245bc8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Jun 2022 02:50:42 +0000 Subject: [PATCH 15/35] finish first version sde ve --- src/diffusers/__init__.py | 2 +- .../models/unet_sde_score_estimation.py | 25 ++--- src/diffusers/pipelines/__init__.py | 3 + src/diffusers/pipelines/pipeline_score_sde.py | 94 +++++++++++++++++++ src/diffusers/schedulers/__init__.py | 1 + src/diffusers/schedulers/scheduling_ve_sde.py | 73 ++++++++++++++ 6 files changed, 180 insertions(+), 18 deletions(-) create mode 100755 src/diffusers/pipelines/pipeline_score_sde.py create mode 100644 src/diffusers/schedulers/scheduling_ve_sde.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ba6df51070..ac68a6c309 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline -from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin +from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler if is_transformers_available(): diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 30671ef293..d46782c7af 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -15,10 +15,6 @@ # helpers functions -from ..modeling_utils import ModelMixin -from ..configuration_utils import ConfigMixin - - import functools import math import string @@ -28,16 +24,15 @@ import torch import torch.nn as nn import torch.nn.functional as F +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin + def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): - return upfirdn2d_native( - input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] - ) + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) -def upfirdn2d_native( - input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 -): +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): _, channel, in_h, in_w = input.shape input = input.reshape(-1, in_h, in_w, 1) @@ -48,9 +43,7 @@ def upfirdn2d_native( out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) - out = F.pad( - out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] - ) + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), @@ -59,9 +52,7 @@ def upfirdn2d_native( ] out = out.permute(0, 3, 1, 2) - out = out.reshape( - [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] - ) + out = 
out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = F.conv2d(out, w) out = out.reshape( @@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3 def _einsum(a, b, c, x, y): - einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) + einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c)) return torch.einsum(einsum_str, x, y) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index d26c5fc8a7..e724149acf 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline from .pipeline_pndm import PNDMPipeline +# from .pipeline_score_sde import NCSNppPipeline + + if is_transformers_available(): from .pipeline_glide import GlidePipeline from .pipeline_latent_diffusion import LatentDiffusionPipeline diff --git a/src/diffusers/pipelines/pipeline_score_sde.py b/src/diffusers/pipelines/pipeline_score_sde.py new file mode 100755 index 0000000000..5b3cb5bcea --- /dev/null +++ b/src/diffusers/pipelines/pipeline_score_sde.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +import numpy as np +import torch + +import PIL +from diffusers import DiffusionPipeline + + +# from configs.ve import ffhq_ncsnpp_continuous as configs +# from configs.ve import cifar10_ncsnpp_continuous as configs + +# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" +# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" +# Note usually we need to restore ema etc... +# ema restored checkpoint used from below +torch.backends.cuda.matmul.allow_tf32 = False +torch.manual_seed(0) + + +class NCSNppPipeline(DiffusionPipeline): + def __init__(self, model, scheduler): + super().__init__() + self.register_modules(model=model, scheduler=scheduler) + + def __call__(self, generator=None): + N = self.scheduler.config.N + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + img_size = self.model.config.image_size + channels = self.model.config.num_channels + shape = (1, channels, img_size, img_size) + + model = torch.nn.DataParallel(self.model.to(device)) + + centered = False + n_steps = 1 + + # Initial sample + x = torch.randn(*shape) * self.scheduler.config.sigma_max + x = x.to(device) + + for i in range(N): + sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device) + + for _ in range(n_steps): + with torch.no_grad(): + result = model(x, sigma_t) + x = self.scheduler.step_correct(result, x) + + with torch.no_grad(): + result = model(x, sigma_t) + + x, x_mean = self.scheduler.step_pred(result, x, i) + + x = x_mean + + if centered: + x = (x + 1.0) / 2.0 + + return x + + +pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp") +x = pipeline() + + +# for 5 cifar10 +# x_sum = 106071.9922 +# x_mean = 34.52864456176758 + +# for 1000 cifar10 +# x_sum = 461.9700 +# x_mean = 0.1504 + +# for N=2 for 1024 +x_sum = 3382810112.0 +x_mean = 1075.366455078125 + + +def check_x_sum_x_mean(x, x_sum, x_mean): + assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" + assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" + + +check_x_sum_x_mean(x, x_sum, x_mean) + + +def save_image(x): + image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) + image_pil = PIL.Image.fromarray(image_processed[0]) + image_pil.save("../images/hey.png") + + 
+# save_image(x) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b2d533d380..ea30626670 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_utils import SchedulerMixin +from .scheduling_ve_sde import VeSdeScheduler diff --git a/src/diffusers/schedulers/scheduling_ve_sde.py b/src/diffusers/schedulers/scheduling_ve_sde.py new file mode 100644 index 0000000000..6f188272b2 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ve_sde.py @@ -0,0 +1,73 @@ +# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class VeSdeScheduler(SchedulerMixin, ConfigMixin): + def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): + super().__init__() + self.register_to_config( + snr=snr, + sigma_min=sigma_min, + sigma_max=sigma_max, + N=N, + sampling_eps=sampling_eps, + ) + # (PVP) - clean up with .config. 
+ self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.snr = snr + self.N = N + self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) + self.timesteps = torch.linspace(1, sampling_eps, N) + + def get_sigma_t(self, t): + return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t] + + def step_pred(self, result, x, t): + t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device) + + timestep = (t * (self.N - 1)).long() + sigma = self.discrete_sigmas.to(t.device)[timestep] + adjacent_sigma = torch.where( + timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device) + ) + f = torch.zeros_like(x) + G = torch.sqrt(sigma**2 - adjacent_sigma**2) + + f = f - G[:, None, None, None] ** 2 * result + + z = torch.randn_like(x) + x_mean = x - f + x = x_mean + G[:, None, None, None] * z + return x, x_mean + + def step_correct(self, result, x): + noise = torch.randn_like(x) + grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * torch.ones(x.shape[0], device=x.device) + x_mean = x + step_size[:, None, None, None] * result + + x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise + + return x From 433cb3f801470feef0a6bab3c90c7b303926ec98 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Jun 2022 18:25:43 +0000 Subject: [PATCH 16/35] clean up sde ve more --- README.md | 24 ++++++ src/diffusers/__init__.py | 11 ++- src/diffusers/pipelines/__init__.py | 3 +- ..._score_sde.py => pipeline_score_sde_ve.py} | 76 +++++++++---------- src/diffusers/schedulers/__init__.py | 2 +- ...eduling_ve_sde.py => scheduling_sde_ve.py} | 42 ++++++---- tests/test_modeling_utils.py | 20 +++++ 7 files changed, 120 insertions(+), 58 deletions(-) rename src/diffusers/pipelines/{pipeline_score_sde.py => pipeline_score_sde_ve.py} (53%) rename src/diffusers/schedulers/{scheduling_ve_sde.py => scheduling_sde_ve.py} (63%) diff --git a/README.md b/README.md index 6c2c9799c2..bee5d880f0 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,30 @@ image_pil = PIL.Image.fromarray(image_processed[0]) image_pil.save("test.png") ``` +#### **Example 1024x1024 image generation with SDE VE** + +See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE. 
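
For orientation before the example: `VeSdeScheduler` discretizes the variance-exploding SDE from the paper with a geometric sigma schedule, and `step_pred` applies one reverse-diffusion step per sigma. A minimal standalone sketch of those two pieces (illustrative only; `predictor_step` is a hypothetical helper, not a library API, and the constants are the scheduler defaults shown above):

```python
import torch

# defaults from the VeSdeScheduler config above
sigma_min, sigma_max, N, sampling_eps = 0.01, 1348, 2, 1e-5

# geometric noise schedule, as in get_sigma_t
t = torch.linspace(1, sampling_eps, N)
sigmas = sigma_min * (sigma_max / sigma_min) ** t


def predictor_step(x, score, sigma, adjacent_sigma):
    # one reverse-diffusion step, mirroring step_pred:
    # x <- x + (sigma_i^2 - sigma_{i-1}^2) * score + sqrt(sigma_i^2 - sigma_{i-1}^2) * z
    G = torch.sqrt(sigma**2 - adjacent_sigma**2)
    x_mean = x + G**2 * score
    return x_mean + G * torch.randn_like(x), x_mean
```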
+ +```python +from diffusers import DiffusionPipeline +import torch +import PIL.Image + +torch.manual_seed(32) + +score_sde_sv = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp") + +# Note this might take up to 3 minutes on a GPU +image = score_sde_sv(num_inference_steps=2000) + +image = image.permute(0, 2, 3, 1).cpu().numpy() +image = np.clip(image * 255, 0, 255).astype(np.uint8) +image_pil = PIL.Image.fromarray(image[0]) + +# save image +image_pil.save("test.png") +``` + #### **Text to Image generation with Latent Diffusion** _Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ac68a6c309..d851608376 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -9,8 +9,15 @@ __version__ = "0.0.4" from .modeling_utils import ModelMixin from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline -from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline -from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler +from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline +from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + GradTTSScheduler, + PNDMScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, +) if is_transformers_available(): diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e724149acf..b579652e25 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -3,9 +3,10 @@ from .pipeline_bddm import BDDMPipeline from .pipeline_ddim import DDIMPipeline from .pipeline_ddpm import DDPMPipeline from .pipeline_pndm import PNDMPipeline +from .pipeline_score_sde_ve import ScoreSdeVePipeline -# from .pipeline_score_sde import NCSNppPipeline +# from .pipeline_score_sde import ScoreSdeVePipeline if is_transformers_available(): diff --git a/src/diffusers/pipelines/pipeline_score_sde.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py similarity index 53% rename from src/diffusers/pipelines/pipeline_score_sde.py rename to src/diffusers/pipelines/pipeline_score_sde_ve.py index 5b3cb5bcea..ca7592492b 100755 --- a/src/diffusers/pipelines/pipeline_score_sde.py +++ b/src/diffusers/pipelines/pipeline_score_sde_ve.py @@ -6,51 +6,44 @@ import PIL from diffusers import DiffusionPipeline -# from configs.ve import ffhq_ncsnpp_continuous as configs -# from configs.ve import cifar10_ncsnpp_continuous as configs - -# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" -# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" -# Note usually we need to restore ema etc... 
-# ema restored checkpoint used from below -torch.backends.cuda.matmul.allow_tf32 = False -torch.manual_seed(0) +# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names -class NCSNppPipeline(DiffusionPipeline): +class ScoreSdeVePipeline(DiffusionPipeline): def __init__(self, model, scheduler): super().__init__() self.register_modules(model=model, scheduler=scheduler) - def __call__(self, generator=None): - N = self.scheduler.config.N + def __call__(self, num_inference_steps=2000, generator=None): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") img_size = self.model.config.image_size channels = self.model.config.num_channels shape = (1, channels, img_size, img_size) - model = torch.nn.DataParallel(self.model.to(device)) + model = self.model.to(device) centered = False n_steps = 1 - # Initial sample x = torch.randn(*shape) * self.scheduler.config.sigma_max x = x.to(device) - for i in range(N): - sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device) + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) for _ in range(n_steps): with torch.no_grad(): - result = model(x, sigma_t) + result = self.model(x, sigma_t) x = self.scheduler.step_correct(result, x) with torch.no_grad(): result = model(x, sigma_t) - x, x_mean = self.scheduler.step_pred(result, x, i) + x, x_mean = self.scheduler.step_pred(result, x, t) x = x_mean @@ -60,9 +53,16 @@ class NCSNppPipeline(DiffusionPipeline): return x -pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp") -x = pipeline() +# from configs.ve import ffhq_ncsnpp_continuous as configs +# from configs.ve import cifar10_ncsnpp_continuous as configs +# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" +# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" +# Note usually we need to restore ema etc... 
+# ema restored checkpoint used from below + +# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp") +# x = pipeline(num_inference_steps=2) # for 5 cifar10 # x_sum = 106071.9922 @@ -73,22 +73,22 @@ x = pipeline() # x_mean = 0.1504 # for N=2 for 1024 -x_sum = 3382810112.0 -x_mean = 1075.366455078125 - - -def check_x_sum_x_mean(x, x_sum, x_mean): - assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" - assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" - - -check_x_sum_x_mean(x, x_sum, x_mean) - - -def save_image(x): - image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) - image_pil = PIL.Image.fromarray(image_processed[0]) - image_pil.save("../images/hey.png") - - +# x_sum = 3382810112.0 +# x_mean = 1075.366455078125 +# +# +# def check_x_sum_x_mean(x, x_sum, x_mean): +# assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" +# assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" +# +# +# check_x_sum_x_mean(x, x_sum, x_mean) +# +# +# def save_image(x): +# image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) +# image_pil = PIL.Image.fromarray(image_processed[0]) +# image_pil.save("../images/hey.png") +# +# # save_image(x) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index ea30626670..36bc441b50 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -21,4 +21,4 @@ from .scheduling_ddpm import DDPMScheduler from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_utils import SchedulerMixin -from .scheduling_ve_sde import VeSdeScheduler +from .scheduling_sde_ve import ScoreSdeVeScheduler diff --git a/src/diffusers/schedulers/scheduling_ve_sde.py b/src/diffusers/schedulers/scheduling_sde_ve.py similarity index 63% rename from src/diffusers/schedulers/scheduling_ve_sde.py rename to src/diffusers/schedulers/scheduling_sde_ve.py index 6f188272b2..652314b9c9 100644 --- a/src/diffusers/schedulers/scheduling_ve_sde.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -1,4 +1,4 @@ -# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit import numpy as np import torch @@ -21,7 +23,7 @@ from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin -class VeSdeScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): super().__init__() self.register_to_config( @@ -31,24 +33,32 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin): N=N, sampling_eps=sampling_eps, ) - # (PVP) - clean up with .config. 
- self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.snr = snr - self.N = N - self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) - self.timesteps = torch.linspace(1, sampling_eps, N) - def get_sigma_t(self, t): - return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t] + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def set_sigmas(self, num_inference_steps): + if self.timesteps is None: + self.set_timesteps(num_inference_steps) + + self.discrete_sigmas = torch.exp( + torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) + ) + self.sigmas = torch.tensor( + [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] + ) def step_pred(self, result, x, t): - t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device) + t = t * torch.ones(x.shape[0], device=x.device) + timestep = (t * (2 - 1)).long() - timestep = (t * (self.N - 1)).long() sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where( - timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device) + timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device) ) f = torch.zeros_like(x) G = torch.sqrt(sigma**2 - adjacent_sigma**2) @@ -64,7 +74,7 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin): noise = torch.randn_like(x) grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() - step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 step_size = step_size * torch.ones(x.shape[0], device=x.device) x_mean = x + step_size[:, None, None, None] * result diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index db4ed6eb02..15547afba6 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -33,8 +33,11 @@ from diffusers import ( GradTTSPipeline, GradTTSScheduler, LatentDiffusionPipeline, + NCSNpp, PNDMPipeline, PNDMScheduler, + ScoreSdeVePipeline, + ScoreSdeVeScheduler, UNetGradTTSModel, UNetLDMModel, UNetModel, @@ -721,6 +724,23 @@ class PipelineTesterMixin(unittest.TestCase): ) assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2 + @slow + def test_score_sde_ve_pipeline(self): + torch.manual_seed(0) + + model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp") + scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp") + + sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) + + image = sde_ve(num_inference_steps=2) + + expected_image_sum = 3382810112.0 + expected_image_mean = 1075.366455078125 + + assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 + assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) noise_scheduler = DDPMScheduler(timesteps=12) From 135acd83af86b02c1dfb3bdb5650d19ef10332b2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Jun 2022 00:56:18 +0000 Subject: [PATCH 17/35] fix bug --- src/diffusers/schedulers/scheduling_sde_ve.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py 
b/src/diffusers/schedulers/scheduling_sde_ve.py index 652314b9c9..2456afad7d 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -24,13 +24,12 @@ from .scheduling_utils import SchedulerMixin class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): - def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): + def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"): super().__init__() self.register_to_config( snr=snr, sigma_min=sigma_min, sigma_max=sigma_max, - N=N, sampling_eps=sampling_eps, ) @@ -54,7 +53,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): def step_pred(self, result, x, t): t = t * torch.ones(x.shape[0], device=x.device) - timestep = (t * (2 - 1)).long() + timestep = (t * (len(self.timesteps) - 1)).long() sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where( From d5c527a499cf284f6756e0a28b68e14e808dfcc9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Jun 2022 11:02:57 +0000 Subject: [PATCH 18/35] clean up --- .../pipelines/pipeline_score_sde_ve.py | 55 +------------------ 1 file changed, 2 insertions(+), 53 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_score_sde_ve.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py index ca7592492b..a1a4843af1 100755 --- a/src/diffusers/pipelines/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/pipeline_score_sde_ve.py @@ -1,14 +1,9 @@ #!/usr/bin/env python3 -import numpy as np import torch - -import PIL from diffusers import DiffusionPipeline # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names - - class ScoreSdeVePipeline(DiffusionPipeline): def __init__(self, model, scheduler): super().__init__() @@ -23,7 +18,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): model = self.model.to(device) - centered = False + # TODO(Patrick) move to scheduler config n_steps = 1 x = torch.randn(*shape) * self.scheduler.config.sigma_max @@ -45,50 +40,4 @@ class ScoreSdeVePipeline(DiffusionPipeline): x, x_mean = self.scheduler.step_pred(result, x, t) - x = x_mean - - if centered: - x = (x + 1.0) / 2.0 - - return x - - -# from configs.ve import ffhq_ncsnpp_continuous as configs -# from configs.ve import cifar10_ncsnpp_continuous as configs - -# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" -# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" -# Note usually we need to restore ema etc... 
-# ema restored checkpoint used from below - -# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp") -# x = pipeline(num_inference_steps=2) - -# for 5 cifar10 -# x_sum = 106071.9922 -# x_mean = 34.52864456176758 - -# for 1000 cifar10 -# x_sum = 461.9700 -# x_mean = 0.1504 - -# for N=2 for 1024 -# x_sum = 3382810112.0 -# x_mean = 1075.366455078125 -# -# -# def check_x_sum_x_mean(x, x_sum, x_mean): -# assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" -# assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" -# -# -# check_x_sum_x_mean(x, x_sum, x_mean) -# -# -# def save_image(x): -# image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) -# image_pil = PIL.Image.fromarray(image_processed[0]) -# image_pil.save("../images/hey.png") -# -# -# save_image(x) + return x_mean From dc6d028654c7a6f1ae22728bddf4509206127ac0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Jun 2022 23:41:55 +0000 Subject: [PATCH 19/35] add vp sampler --- src/diffusers/__init__.py | 3 +- .../models/unet_sde_score_estimation.py | 2 +- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/pipeline_score_sde_ve.py | 0 .../pipelines/pipeline_score_sde_vp.py | 42 ++++++++++++++ src/diffusers/schedulers/__init__.py | 1 + src/diffusers/schedulers/scheduling_sde_vp.py | 55 +++++++++++++++++++ tests/test_modeling_utils.py | 19 +++++++ 8 files changed, 121 insertions(+), 2 deletions(-) mode change 100755 => 100644 src/diffusers/pipelines/pipeline_score_sde_ve.py create mode 100644 src/diffusers/pipelines/pipeline_score_sde_vp.py create mode 100644 src/diffusers/schedulers/scheduling_sde_vp.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d851608376..213b9a5bcc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -9,7 +9,7 @@ __version__ = "0.0.4" from .modeling_utils import ModelMixin from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline -from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline +from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline from .schedulers import ( DDIMScheduler, DDPMScheduler, @@ -17,6 +17,7 @@ from .schedulers import ( PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler, + ScoreSdeVpScheduler, ) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index d46782c7af..784d528dd4 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -766,7 +766,7 @@ class NCSNpp(ModelMixin, ConfigMixin): continuous=continuous, ) self.act = act = get_act(nonlinearity) - # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) + self.register_buffer('sigmas', torch.tensor(np.linspace(np.log(50), np.log(0.01), 10))) self.nf = nf self.num_res_blocks = num_res_blocks diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b579652e25..5d7b1f14cf 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -4,6 +4,7 @@ from .pipeline_ddim import DDIMPipeline from .pipeline_ddpm import DDPMPipeline from .pipeline_pndm import PNDMPipeline from .pipeline_score_sde_ve import ScoreSdeVePipeline +from .pipeline_score_sde_vp import ScoreSdeVpPipeline # from 
.pipeline_score_sde import ScoreSdeVePipeline diff --git a/src/diffusers/pipelines/pipeline_score_sde_ve.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py old mode 100755 new mode 100644 diff --git a/src/diffusers/pipelines/pipeline_score_sde_vp.py b/src/diffusers/pipelines/pipeline_score_sde_vp.py new file mode 100644 index 0000000000..9eb886296b --- /dev/null +++ b/src/diffusers/pipelines/pipeline_score_sde_vp.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +import torch +from diffusers import DiffusionPipeline + + +# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names +class ScoreSdeVpPipeline(DiffusionPipeline): + def __init__(self, model, scheduler): + super().__init__() + self.register_modules(model=model, scheduler=scheduler) + + def __call__(self, num_inference_steps=1000, generator=None): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + img_size = self.model.config.image_size + channels = self.model.config.num_channels + shape = (1, channels, img_size, img_size) + + beta_min, beta_max = 0.1, 20 + + model = self.model.to(device) + + x = torch.randn(*shape).to(device) + + self.scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + t = t * torch.ones(shape[0], device=device) + sigma_t = t * (num_inference_steps - 1) + + with torch.no_grad(): + result = model(x, sigma_t) + + log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min + std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + result = -result / std[:, None, None, None] + + x, x_mean = self.scheduler.step_pred(result, x, t) + + x_mean = (x_mean + 1.) / 2. + + return x_mean diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 36bc441b50..6a6d628661 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -22,3 +22,4 @@ from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_utils import SchedulerMixin from .scheduling_sde_ve import ScoreSdeVeScheduler +from .scheduling_sde_vp import ScoreSdeVpScheduler diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py new file mode 100644 index 0000000000..c7b6497117 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -0,0 +1,55 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): + def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): + super().__init__() + self.register_to_config( + beta_min=beta_min, + beta_max=beta_max, + sampling_eps=sampling_eps, + ) + + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def step_pred(self, result, x, t): + dt = -1. / len(self.timesteps) + z = torch.randn_like(x) + + beta_t = self.beta_min + t * (self.beta_max - self.beta_min) + drift = -0.5 * beta_t[:, None, None, None] * x + diffusion = torch.sqrt(beta_t) + + drift = drift - diffusion[:, None, None, None] ** 2 * result + + x_mean = x + drift * dt + x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z + + return x, x_mean diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 15547afba6..32bc3003c5 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -38,6 +38,8 @@ from diffusers import ( PNDMScheduler, ScoreSdeVePipeline, ScoreSdeVeScheduler, + ScoreSdeVpPipeline, + ScoreSdeVpScheduler, UNetGradTTSModel, UNetLDMModel, UNetModel, @@ -741,6 +743,23 @@ class PipelineTesterMixin(unittest.TestCase): assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + @slow + def test_score_sde_vp_pipeline(self): + + model = NCSNpp.from_pretrained("/home/patrick/cifar10-ddpmpp-vp") + scheduler = ScoreSdeVpScheduler() + + sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler) + + torch.manual_seed(0) + image = sde_vp(num_inference_steps=10) + + expected_image_sum = 4183.2012 + expected_image_mean = 1.3617 + + assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 + assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) noise_scheduler = DDPMScheduler(timesteps=12) From ba264419f40b94fd2e8135096db4780e1c188aef Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 00:07:57 +0000 Subject: [PATCH 20/35] finish vp --- .../models/unet_sde_score_estimation.py | 5 ++--- .../pipelines/pipeline_score_sde_ve.py | 1 + .../pipelines/pipeline_score_sde_vp.py | 15 +++++---------- src/diffusers/schedulers/__init__.py | 2 +- src/diffusers/schedulers/scheduling_sde_ve.py | 2 ++ src/diffusers/schedulers/scheduling_sde_vp.py | 19 ++++++++++++++----- tests/test_modeling_utils.py | 4 ++-- 7 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 784d528dd4..299f96c9cd 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -766,7 +766,6 @@ class NCSNpp(ModelMixin, ConfigMixin): continuous=continuous, ) self.act = act = get_act(nonlinearity) - self.register_buffer('sigmas', torch.tensor(np.linspace(np.log(50), np.log(0.01), 10))) self.nf = nf self.num_res_blocks = num_res_blocks @@ 
-939,7 +938,7 @@ class NCSNpp(ModelMixin, ConfigMixin): self.all_modules = nn.ModuleList(modules) - def forward(self, x, time_cond): + def forward(self, x, time_cond, sigmas=None): # timestep/noise_level embedding; only for continuous training modules = self.all_modules m_idx = 0 @@ -952,7 +951,7 @@ class NCSNpp(ModelMixin, ConfigMixin): elif self.embedding_type == "positional": # Sinusoidal positional embeddings. timesteps = time_cond - used_sigmas = self.sigmas[time_cond.long()] + used_sigmas = sigmas temb = get_timestep_embedding(timesteps, self.nf) else: diff --git a/src/diffusers/pipelines/pipeline_score_sde_ve.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py index a1a4843af1..1dfd304d83 100644 --- a/src/diffusers/pipelines/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/pipeline_score_sde_ve.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import torch + from diffusers import DiffusionPipeline diff --git a/src/diffusers/pipelines/pipeline_score_sde_vp.py b/src/diffusers/pipelines/pipeline_score_sde_vp.py index 9eb886296b..29551d9a6e 100644 --- a/src/diffusers/pipelines/pipeline_score_sde_vp.py +++ b/src/diffusers/pipelines/pipeline_score_sde_vp.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import torch + from diffusers import DiffusionPipeline @@ -16,27 +17,21 @@ class ScoreSdeVpPipeline(DiffusionPipeline): channels = self.model.config.num_channels shape = (1, channels, img_size, img_size) - beta_min, beta_max = 0.1, 20 - model = self.model.to(device) x = torch.randn(*shape).to(device) self.scheduler.set_timesteps(num_inference_steps) - for i, t in enumerate(self.scheduler.timesteps): + for t in self.scheduler.timesteps: t = t * torch.ones(shape[0], device=device) - sigma_t = t * (num_inference_steps - 1) + scaled_t = t * (num_inference_steps - 1) with torch.no_grad(): - result = model(x, sigma_t) - - log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min - std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) - result = -result / std[:, None, None, None] + result = model(x, scaled_t) x, x_mean = self.scheduler.step_pred(result, x, t) - x_mean = (x_mean + 1.) / 2. 
+ x_mean = (x_mean + 1.0) / 2.0 return x_mean diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 6a6d628661..ad66fe5991 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -20,6 +20,6 @@ from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler -from .scheduling_utils import SchedulerMixin from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler +from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 2456afad7d..79936105b9 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -52,6 +52,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ) def step_pred(self, result, x, t): + # TODO(Patrick) better comments + non-PyTorch t = t * torch.ones(x.shape[0], device=x.device) timestep = (t * (len(self.timesteps) - 1)).long() @@ -70,6 +71,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): return x, x_mean def step_correct(self, result, x): + # TODO(Patrick) better comments + non-PyTorch noise = torch.randn_like(x) grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index c7b6497117..dda32a2742 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -40,16 +40,25 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) def step_pred(self, result, x, t): - dt = -1. 
/ len(self.timesteps)
-        z = torch.randn_like(x)
+        # TODO(Patrick) better comments + non-PyTorch
+        # postprocess model result
+        log_mean_coeff = (
+            -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
+        )
+        std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
+        result = -result / std[:, None, None, None]
 
-        beta_t = self.beta_min + t * (self.beta_max - self.beta_min)
+        # compute
+        dt = -1.0 / len(self.timesteps)
+
+        beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
         drift = -0.5 * beta_t[:, None, None, None] * x
         diffusion = torch.sqrt(beta_t)
-
         drift = drift - diffusion[:, None, None, None] ** 2 * result
-
         x_mean = x + drift * dt
+
+        # add noise
+        z = torch.randn_like(x)
         x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
 
         return x, x_mean
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 32bc3003c5..6c5c115f19 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -746,8 +746,8 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_score_sde_vp_pipeline(self):
 
-        model = NCSNpp.from_pretrained("/home/patrick/cifar10-ddpmpp-vp")
-        scheduler = ScoreSdeVpScheduler()
+        model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
+        scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp")
 
         sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler)
 
         torch.manual_seed(0)
         image = sde_vp(num_inference_steps=10)
 
         expected_image_sum = 4183.2012
         expected_image_mean = 1.3617
 
         assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
         assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
 

From 9a4d53a4762e6b4c8766f66fcb02f78b99f170b5 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 27 Jun 2022 02:09:49 +0200
Subject: [PATCH 21/35] Update README.md

---
 README.md | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/README.md b/README.md
index bee5d880f0..7f2704e5d6 100644
--- a/README.md
+++ b/README.md
@@ -234,6 +234,7 @@ See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.
 from diffusers import DiffusionPipeline
 import torch
 import PIL.Image
+import numpy as np
 
 torch.manual_seed(32)
 
@@ -249,6 +250,31 @@ image_pil = PIL.Image.fromarray(image[0])
 
 # save image
 image_pil.save("test.png")
 ```
+#### **Example 32x32 image generation with SDE VP**
+
+See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VP.
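
The `ScoreSdeVpScheduler` cleaned up in this patch follows the variance-preserving SDE instead. As a sketch of what `step_pred` computes (an assumption drawn from the paper: beta is the linear schedule, dt < 0, z is standard Gaussian noise, and the raw model output is first rescaled by the marginal standard deviation to give the score s_theta):

```latex
% linear beta schedule and reverse-time Euler-Maruyama update (step_pred)
\beta(t) = \beta_{\min} + t\,(\beta_{\max} - \beta_{\min}),
\qquad
x_{\mathrm{mean}} = x + \left[-\tfrac{1}{2}\,\beta(t)\,x - \beta(t)\,s_\theta(x,t)\right] dt,
\qquad
x = x_{\mathrm{mean}} + \sqrt{\beta(t)\,\lvert dt\rvert}\;z
```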
+ +```python +from diffusers import DiffusionPipeline +import torch +import PIL.Image +import numpy as np + +torch.manual_seed(32) + +score_sde_sv = DiffusionPipeline.from_pretrained("fusing/cifar10-ddpmpp-deep-vp") + +# Note this might take up to 3 minutes on a GPU +image = score_sde_sv(num_inference_steps=1000) + +image = image.permute(0, 2, 3, 1).cpu().numpy() +image = np.clip(image * 255, 0, 255).astype(np.uint8) +image_pil = PIL.Image.fromarray(image[0]) + +# save image +image_pil.save("test.png") +``` + #### **Text to Image generation with Latent Diffusion** From f6e8c8c09cb476b59c2d28c6fd7a969dee575164 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 10:46:13 +0200 Subject: [PATCH 22/35] add layers --- src/diffusers/models/attention2d.py | 0 src/diffusers/models/resnet.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/diffusers/models/attention2d.py create mode 100644 src/diffusers/models/resnet.py diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention2d.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py new file mode 100644 index 0000000000..e69de29bb2 From 45a09bebf38a201e92002739a63eaa4b1f608920 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 10:46:39 +0200 Subject: [PATCH 23/35] add first files --- tests/test_layers_utils.py | 738 +++++++++++++++++++++++++++++++++++++ 1 file changed, 738 insertions(+) create mode 100755 tests/test_layers_utils.py diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py new file mode 100755 index 0000000000..db4ed6eb02 --- /dev/null +++ b/tests/test_layers_utils.py @@ -0,0 +1,738 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import inspect +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import ( + BDDMPipeline, + DDIMPipeline, + DDIMScheduler, + DDPMPipeline, + DDPMScheduler, + GlidePipeline, + GlideSuperResUNetModel, + GlideTextToImageUNetModel, + GradTTSPipeline, + GradTTSScheduler, + LatentDiffusionPipeline, + PNDMPipeline, + PNDMScheduler, + UNetGradTTSModel, + UNetLDMModel, + UNetModel, +) +from diffusers.configuration_utils import ConfigMixin +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.pipeline_bddm import DiffWave +from diffusers.testing_utils import floats_tensor, slow, torch_device + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class ConfigTester(unittest.TestCase): + def test_load_not_from_mixin(self): + with self.assertRaises(ValueError): + ConfigMixin.from_config("dummy_path") + + def test_save_load(self): + class SampleObject(ConfigMixin): + config_name = "config.json" + + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 3], + ): + self.register_to_config(a=a, b=b, c=c, d=d, e=e) + + obj = SampleObject() + config = obj.config + + assert config["a"] == 2 + assert config["b"] == 5 + assert config["c"] == (2, 5) + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + new_obj = SampleObject.from_config(tmpdirname) + new_config = new_obj.config + + # unfreeze configs + config = dict(config) + new_config = dict(new_config) + + assert config.pop("c") == (2, 5) # instantiated as tuple + assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json + assert config == new_config + + +class ModelTesterMixin: + def test_from_pretrained_save_pretrained(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + with torch.no_grad(): + image = model(**inputs_dict) + new_image = new_model(**inputs_dict) + + max_diff = (image - new_image).abs().sum().item() + self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes") + + def test_determinism(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + with torch.no_grad(): + first = model(**inputs_dict) + second = model(**inputs_dict) + + out_1 = first.cpu().numpy() + out_2 = second.cpu().numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + self.assertIsNotNone(output) + expected_shape = inputs_dict["x"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_forward_signature(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = 
[*signature.parameters.keys()] + + expected_arg_names = ["x", "timesteps"] + self.assertListEqual(arg_names[:2], expected_arg_names) + + def test_model_from_config(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # test if the model can be loaded from the config + # and has all the expected shape + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_config(tmpdirname) + new_model = self.model_class.from_config(tmpdirname) + new_model.to(torch_device) + new_model.eval() + + # check if all paramters shape are the same + for param_name in model.state_dict().keys(): + param_1 = model.state_dict()[param_name] + param_2 = new_model.state_dict()[param_name] + self.assertEqual(param_1.shape, param_2.shape) + + with torch.no_grad(): + output_1 = model(**inputs_dict) + output_2 = new_model(**inputs_dict) + + self.assertEqual(output_1.shape, output_2.shape) + + def test_training(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + output = model(**inputs_dict) + noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + + +class UnetModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNetModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + + return {"x": noise, "timesteps": time_step} + + @property + def get_input_shape(self): + return (3, 32, 32) + + @property + def get_output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 32, + "ch_mult": (1, 2), + "num_res_blocks": 2, + "attn_resolutions": (16,), + "resolution": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = UNetModel.from_pretrained("fusing/ddpm_dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + time_step = torch.tensor([10]) + + with torch.no_grad(): + output = model(noise, time_step) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + +class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase): + model_class = GlideSuperResUNetModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 6 + sizes = (32, 32) + low_res_size = (4, 4) + + noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device) + low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device) + time_step = torch.tensor([10] * 
noise.shape[0], device=torch_device) + + return {"x": noise, "timesteps": time_step, "low_res": low_res} + + @property + def get_input_shape(self): + return (3, 32, 32) + + @property + def get_output_shape(self): + return (6, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "attention_resolutions": (2,), + "channel_mult": (1, 2), + "in_channels": 6, + "out_channels": 6, + "model_channels": 32, + "num_head_channels": 8, + "num_heads_upsample": 1, + "num_res_blocks": 2, + "resblock_updown": True, + "resolution": 32, + "use_scale_shift_norm": True, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + output, _ = torch.split(output, 3, dim=1) + + self.assertIsNotNone(output) + expected_shape = inputs_dict["x"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_from_pretrained_hub(self): + model, loading_info = GlideSuperResUNetModel.from_pretrained( + "fusing/glide-super-res-dummy", output_loading_info=True + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = GlideSuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy") + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + noise = torch.randn(1, 3, 64, 64) + low_res = torch.randn(1, 3, 4, 4) + time_step = torch.tensor([42] * noise.shape[0]) + + with torch.no_grad(): + output = model(noise, time_step, low_res) + + output, _ = torch.split(output, 3, dim=1) + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-22.8782, -23.2652, -15.3966, -22.8034, -23.3159, -15.5640, -15.3970, -15.4614, - 10.4370]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + +class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): + model_class = GlideTextToImageUNetModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + transformer_dim = 32 + seq_len = 16 + + noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device) + emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device) + time_step = torch.tensor([10] * noise.shape[0], device=torch_device) + + return {"x": noise, "timesteps": time_step, "transformer_out": emb} + + @property + def get_input_shape(self): + return (3, 32, 32) + + @property + def get_output_shape(self): + return (6, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "attention_resolutions": (2,), + "channel_mult": (1, 2), + "in_channels": 3, + "out_channels": 6, + "model_channels": 32, + "num_head_channels": 8, + "num_heads_upsample": 1, + "num_res_blocks": 2, + "resblock_updown": True, + "resolution": 32, + "use_scale_shift_norm": True, + "transformer_dim": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + 
output = model(**inputs_dict) + + output, _ = torch.split(output, 3, dim=1) + + self.assertIsNotNone(output) + expected_shape = inputs_dict["x"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_from_pretrained_hub(self): + model, loading_info = GlideTextToImageUNetModel.from_pretrained( + "fusing/unet-glide-text2im-dummy", output_loading_info=True + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = GlideTextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy") + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + noise = torch.randn((1, model.config.in_channels, model.config.resolution, model.config.resolution)).to( + torch_device + ) + emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device) + time_step = torch.tensor([10] * noise.shape[0], device=torch_device) + + with torch.no_grad(): + output = model(noise, time_step, emb) + + output, _ = torch.split(output, 3, dim=1) + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + +class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNetLDMModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + + return {"x": noise, "timesteps": time_step} + + @property + def get_input_shape(self): + return (4, 32, 32) + + @property + def get_output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "image_size": 32, + "in_channels": 4, + "out_channels": 4, + "model_channels": 32, + "num_res_blocks": 2, + "attention_resolutions": (16,), + "channel_mult": (1, 2), + "num_heads": 2, + "conv_resample": True, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) + time_step = torch.tensor([10] * noise.shape[0]) + + with torch.no_grad(): + output = model(noise, time_step) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + +class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNetGradTTSModel + + @property + def 
dummy_input(self): + batch_size = 4 + num_features = 32 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + condition = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask} + + @property + def get_input_shape(self): + return (4, 32, 16) + + @property + def get_output_shape(self): + return (4, 32, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "dim": 64, + "groups": 4, + "dim_mults": (1, 2), + "n_feats": 32, + "pe_scale": 1000, + "n_spks": 1, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = model.config.n_feats + seq_len = 16 + noise = torch.randn((1, num_features, seq_len)) + condition = torch.randn((1, num_features, seq_len)) + mask = torch.randn((1, 1, seq_len)) + time_step = torch.tensor([10]) + + with torch.no_grad(): + output = model(noise, time_step, condition, mask) + + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + +class PipelineTesterMixin(unittest.TestCase): + def test_from_pretrained_save_pretrained(self): + # 1. 
Load models + model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) + schedular = DDPMScheduler(timesteps=10) + + ddpm = DDPMPipeline(model, schedular) + + with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) + + generator = torch.manual_seed(0) + + image = ddpm(generator=generator) + generator = generator.manual_seed(0) + new_image = new_ddpm(generator=generator) + + assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" + + @slow + def test_from_pretrained_hub(self): + model_path = "fusing/ddpm-cifar10" + + ddpm = DDPMPipeline.from_pretrained(model_path) + ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) + + ddpm.noise_scheduler.num_timesteps = 10 + ddpm_from_hub.noise_scheduler.num_timesteps = 10 + + generator = torch.manual_seed(0) + + image = ddpm(generator=generator) + generator = generator.manual_seed(0) + new_image = ddpm_from_hub(generator=generator) + + assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" + + @slow + def test_ddpm_cifar10(self): + generator = torch.manual_seed(0) + model_id = "fusing/ddpm-cifar10" + + unet = UNetModel.from_pretrained(model_id) + noise_scheduler = DDPMScheduler.from_config(model_id) + noise_scheduler = noise_scheduler.set_format("pt") + + ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler) + image = ddpm(generator=generator) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 32, 32) + expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_ddim_cifar10(self): + generator = torch.manual_seed(0) + model_id = "fusing/ddpm-cifar10" + + unet = UNetModel.from_pretrained(model_id) + noise_scheduler = DDIMScheduler(tensor_format="pt") + + ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) + image = ddim(generator=generator, eta=0.0) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 32, 32) + expected_slice = torch.tensor( + [-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068] + ) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_pndm_cifar10(self): + generator = torch.manual_seed(0) + model_id = "fusing/ddpm-cifar10" + + unet = UNetModel.from_pretrained(model_id) + noise_scheduler = PNDMScheduler(tensor_format="pt") + + pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler) + image = pndm(generator=generator) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 32, 32) + expected_slice = torch.tensor( + [-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471] + ) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_ldm_text2img(self): + model_id = "fusing/latent-diffusion-text2im-large" + ldm = LatentDiffusionPipeline.from_pretrained(model_id) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + image = ldm([prompt], generator=generator, num_inference_steps=20) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 256, 256) + expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + 
@slow + def test_glide_text2img(self): + model_id = "fusing/glide-base" + glide = GlidePipeline.from_pretrained(model_id) + + prompt = "a pencil sketch of a corgi" + generator = torch.manual_seed(0) + image = glide(prompt, generator=generator, num_inference_steps_upscale=20) + + image_slice = image[0, :3, :3, -1].cpu() + + assert image.shape == (1, 256, 256, 3) + expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_grad_tts(self): + model_id = "fusing/grad-tts-libri-tts" + grad_tts = GradTTSPipeline.from_pretrained(model_id) + noise_scheduler = GradTTSScheduler() + grad_tts.noise_scheduler = noise_scheduler + + text = "Hello world, I missed you so much." + generator = torch.manual_seed(0) + + # generate mel spectrograms using text + mel_spec = grad_tts(text, generator=generator) + + assert mel_spec.shape == (1, 80, 143) + expected_slice = torch.tensor( + [-6.7584, -6.8347, -6.3293, -6.6437, -6.7233, -6.4684, -6.1187, -6.3172, -6.6890] + ) + assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2 + + def test_module_from_pipeline(self): + model = DiffWave(num_res_layers=4) + noise_scheduler = DDPMScheduler(timesteps=12) + + bddm = BDDMPipeline(model, noise_scheduler) + + # check if the library name for the diffwave module is set to pipeline module + self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm") + + # check if we can save and load the pipeline + with tempfile.TemporaryDirectory() as tmpdirname: + bddm.save_pretrained(tmpdirname) + _ = BDDMPipeline.from_pretrained(tmpdirname) + # check if the same works using the DiffusionPipeline class + _ = DiffusionPipeline.from_pretrained(tmpdirname) From fbb103deb6aa48959221b61ded849b47655a8684 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 10:59:22 +0200 Subject: [PATCH 24/35] add the bert model in latent diffusion pipeline --- .../pipelines/pipeline_latent_diffusion.py | 542 +++++++++++++++++- 1 file changed, 541 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py index 7d386765d4..219ad31d04 100644 --- a/src/diffusers/pipelines/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py @@ -1,17 +1,557 @@ -# pytorch_diffusion + derived encoder decoder import math +from typing import Optional, Tuple, Union import numpy as np import torch import torch.nn as nn +import torch.utils.checkpoint import tqdm + +try: + from transformers.activations import ACT2FN + from transformers.configuration_utils import PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput + from transformers.modeling_utils import PreTrainedModel + from transformers.utils import logging +except ImportError: + raise ImportError("Please install the transformers library.") + from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from ..pipeline_utils import DiffusionPipeline +################################################################################ +# Code for the text transformer model +################################################################################ +""" PyTorch LDMBERT model.""" + + +logger = logging.get_logger(__name__) + +LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ldm-bert", + # See all LDMBert models at https://huggingface.co/models?filter=ldmbert +] + + +LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json", +} + + +""" LDMBERT model configuration""" + + +class LDMBertConfig(PretrainedConfig): + model_type = "ldmbert" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=30522, + max_position_embeddings=77, + encoder_layers=32, + encoder_ffn_dim=5120, + encoder_attention_heads=8, + head_dim=64, + encoder_layerdrop=0.0, + activation_function="gelu", + d_model=1280, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.head_dim = head_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for 
the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: LDMBertConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail.
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert +class LDMBertPreTrainedModel(PreTrainedModel): + config_class = LDMBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LDMBertEncoder,)): + module.gradient_checkpointing = value + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class LDMBertEncoder(LDMBertPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`LDMBertEncoderLayer`]. 
+ + Args: + config: LDMBertConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + + self.dropout = config.dropout + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) + self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) + self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + seq_len = input_shape[1] + if position_ids is None: + position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) + embed_pos = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class LDMBertModel(LDMBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = LDMBertEncoder(config) + self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + return sequence_output + + def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From 1a0331a78a8d55efabc32bdf2b3168505a77fd36 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 09:07:57 +0000 Subject: [PATCH 25/35] fix some tests on gpu --- src/diffusers/models/embeddings.py | 56 +++++++++++++++++++++++++++ src/diffusers/models/unet_grad_tts.py | 1 + tests/test_modeling_utils.py | 5 ++- 3 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/models/embeddings.py diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py new file mode 100644 index 0000000000..333aeb85d5 --- /dev/null +++ b/src/diffusers/models/embeddings.py @@ -0,0 +1,56 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# unet.py +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. 
+ This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + +# unet_glide.py +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index a2bdd951e4..81719f088b 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -198,6 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): if not isinstance(spk, type(None)): s = self.spk_mlp(spk) + t = self.time_pos_emb(timesteps, scale=self.pe_scale) t = self.mlp(t) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 6c5c115f19..4556753006 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -113,7 +113,7 @@ class ModelTesterMixin: new_image = new_model(**inputs_dict) max_diff = (image - new_image).abs().sum().item() - self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes") + self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") def test_determinism(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -431,11 +431,12 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device) + model.to(torch_device) with torch.no_grad(): output = model(noise, time_step, emb) output, _ = torch.split(output, 3, dim=1) - output_slice = output[0, -1, -3:, -3:].flatten() + output_slice = output[0, -1, -3:, -3:].cpu().flatten() # fmt: off expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845]) # fmt: on From 7c120874bedb11fc6fda619e99dc8ab6fb922f48 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 11:09:21 +0200 Subject: [PATCH 26/35] fix LatentDiffusionPipeline --- src/diffusers/pipelines/pipeline_latent_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py index 219ad31d04..aa0ec3810e 100644 --- a/src/diffusers/pipelines/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py @@ -1435,7 
+1435,7 @@ class LatentDiffusionPipeline(DiffusionPipeline): # get text embedding text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) - text_embedding = self.bert(text_input.input_ids)[0] + text_embedding = self.bert(text_input.input_ids) num_trained_timesteps = self.noise_scheduler.config.timesteps inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) From 43bf361a7a9e652ed2982198652145a3d67acb31 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 11:10:10 +0200 Subject: [PATCH 27/35] fix more LatentDiffusionPipeline --- src/diffusers/pipelines/pipeline_latent_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py index aa0ec3810e..ffc8ae670c 100644 --- a/src/diffusers/pipelines/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py @@ -1431,7 +1431,7 @@ class LatentDiffusionPipeline(DiffusionPipeline): uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to( torch_device ) - uncond_embeddings = self.bert(uncond_input.input_ids)[0] + uncond_embeddings = self.bert(uncond_input.input_ids) # get text embedding text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) From 168e5b7ffa4949fca82ed2fcd17d3451c5804401 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 09:23:10 +0000 Subject: [PATCH 28/35] add embeddings --- src/diffusers/models/embeddings.py | 80 ++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 333aeb85d5..704f72f9b6 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -54,3 +54,83 @@ def timestep_embedding(timesteps, dim, max_period=10000): if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding + +# unet_grad_tts.py +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super(SinusoidalPosEmb, self).__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +# unet_ldm.py +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + +# unet_rl.py +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +# unet_sde_score_estimation.py +def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + half_dim = embedding_dim // 2 + # magic number 10000 is from transformers + emb = math.log(max_positions) / (half_dim - 1) + # emb = math.log(2.) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) + # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] + # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = F.pad(emb, (0, 1), mode="constant") + assert emb.shape == (timesteps.shape[0], embedding_dim) + return emb + +# unet_sde_score_estimation.py +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__(self, embedding_size=256, scale=1.0): + super().__init__() + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) From 17bf65e1868ef2821cae8769b1da7258e7f01f4c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 11:39:19 +0200 Subject: [PATCH 29/35] skip test_ldm_text2img for now --- tests/test_modeling_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 4556753006..191d28ee6d 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -679,6 +679,7 @@ class PipelineTesterMixin(unittest.TestCase): assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 @slow + @unittest.skip("Skipping for now as it takes too long") def test_ldm_text2img(self): model_id = "fusing/latent-diffusion-text2im-large" ldm = LatentDiffusionPipeline.from_pretrained(model_id) From 6921393ae27f7ab9f3f25f9e772ec42cfdf82f63 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 11:42:52 +0200 Subject: [PATCH 30/35] add fast test for ldm --- src/diffusers/models/embeddings.py | 1 + src/diffusers/models/unet_grad_tts.py | 1 - tests/test_modeling_utils.py | 15 +++++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 333aeb85d5..fbeb37a02a 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -34,6 +34,7 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb + # unet_glide.py 
def timestep_embedding(timesteps, dim, max_period=10000): """ diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 81719f088b..a2bdd951e4 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -198,7 +198,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): if not isinstance(spk, type(None)): s = self.spk_mlp(spk) - t = self.time_pos_emb(timesteps, scale=self.pe_scale) t = self.mlp(t) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 191d28ee6d..2cb8196826 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -694,6 +694,21 @@ class PipelineTesterMixin(unittest.TestCase): expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + @slow + def test_ldm_text2img_fast(self): + model_id = "fusing/latent-diffusion-text2im-large" + ldm = LatentDiffusionPipeline.from_pretrained(model_id) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + image = ldm([prompt], generator=generator, num_inference_steps=20) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 256, 256) + expected_slice = torch.rensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + @slow def test_glide_text2img(self): model_id = "fusing/glide-base" From b7f0ce5b39acc2e0d879c55c56b374738be7c027 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 11:44:05 +0200 Subject: [PATCH 31/35] fix test_ldm_text2img_fast --- tests/test_modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 2cb8196826..ea94f48608 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -706,7 +706,7 @@ class PipelineTesterMixin(unittest.TestCase): image_slice = image[0, -1, -3:, -3:].cpu() assert image.shape == (1, 3, 256, 256) - expected_slice = torch.rensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) + expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 @slow From 9b9afc9726ef1656552bc7bfa2e5afac696b2070 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 27 Jun 2022 11:46:50 +0200 Subject: [PATCH 32/35] actually fix test_ldm_text2img_fast --- tests/test_modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index ea94f48608..697a377f8c 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -701,7 +701,7 @@ class PipelineTesterMixin(unittest.TestCase): prompt = "A painting of a squirrel eating a burger" generator = torch.manual_seed(0) - image = ldm([prompt], generator=generator, num_inference_steps=20) + image = ldm([prompt], generator=generator, num_inference_steps=1) image_slice = image[0, -1, -3:, -3:].cpu() From 02a76c2c81915846eb679ce9f24fbe9806e49c20 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 10:14:54 +0000 Subject: [PATCH 33/35] consolidate timestep embeds --- src/diffusers/models/embeddings.py | 153 ++-- src/diffusers/models/unet.py | 39 +- src/diffusers/models/unet_glide.py | 45 +- src/diffusers/models/unet_ldm.py | 52 
+- .../models/unet_sde_score_estimation.py | 23 +- tests/test_layers_utils.py | 714 +----------------- 6 files changed, 173 insertions(+), 853 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 704f72f9b6..a9143053a6 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -11,49 +11,104 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch +import math +import numpy as np + +from torch import nn +import torch.nn.functional as F -# unet.py -def get_timestep_embedding(timesteps, embedding_dim): +def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000): """ This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - -# unet_glide.py -def timestep_embedding(timesteps, dim, max_period=10000): - """ Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. - :param dim: the dimension of the output. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. """ - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (embedding_dim // 2 - downscale_freq_shift)) + + emb = emb.to(device=timesteps.device) + emb = timesteps[:, None].float() * emb[None, :] + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +#def get_timestep_embedding(timesteps, embedding_dim): +# """ +# This matches the implementation in Denoising Diffusion Probabilistic Models: +# From Fairseq. +# Build sinusoidal embeddings. +# This matches the implementation in tensor2tensor, but differs slightly +# from the description in Section 3.5 of "Attention Is All You Need". 
+# """ +# assert len(timesteps.shape) == 1 +# +# half_dim = embedding_dim // 2 +# emb = math.log(10000) / (half_dim - 1) +# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) +# emb = emb.to(device=timesteps.device) +# emb = timesteps.float()[:, None] * emb[None, :] +# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) +# if embedding_dim % 2 == 1: # zero pad +# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + + +#def timestep_embedding(timesteps, dim, max_period=10000): +# """ +# Create sinusoidal timestep embeddings. +# +# :param timesteps: a 1-D Tensor of N indices, one per batch element. +# These may be fractional. +# :param dim: the dimension of the output. +# :param max_period: controls the minimum frequency of the embeddings. +# :return: an [N x dim] Tensor of positional embeddings. +# """ +# half = dim // 2 +# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( +# device=timesteps.device +# ) +# args = timesteps[:, None].float() * freqs[None, :] +# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) +# if dim % 2: +# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) +# return embedding + + +#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): +# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 +# half_dim = embedding_dim // 2 + # magic number 10000 is from transformers +# emb = math.log(max_positions) / (half_dim - 1) + # emb = math.log(2.) / (half_dim - 1) +# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) + # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] + # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] +# emb = timesteps.float()[:, None] * emb[None, :] +# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) +# if embedding_dim % 2 == 1: # zero pad +# emb = F.pad(emb, (0, 1), mode="constant") +# assert emb.shape == (timesteps.shape[0], embedding_dim) +# return emb + # unet_grad_tts.py class SinusoidalPosEmb(torch.nn.Module): @@ -70,26 +125,6 @@ class SinusoidalPosEmb(torch.nn.Module): emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb -# unet_ldm.py -def timestep_embedding(timesteps, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding # unet_rl.py class SinusoidalPosEmb(nn.Module): @@ -106,22 +141,6 @@ class SinusoidalPosEmb(nn.Module): emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb -# unet_sde_score_estimation.py -def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): - assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 - half_dim = embedding_dim // 2 - # magic number 10000 is from transformers - emb = math.log(max_positions) / (half_dim - 1) - # emb = math.log(2.) 
/ (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) - # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] - # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = F.pad(emb, (0, 1), mode="constant") - assert emb.shape == (timesteps.shape[0], embedding_dim) - return emb # unet_sde_score_estimation.py class GaussianFourierProjection(nn.Module): diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index a4e1e22df8..7d5eebfd3d 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -30,27 +30,28 @@ from tqdm import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .embeddings import get_timestep_embedding -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb +#def get_timestep_embedding(timesteps, embedding_dim): +# """ +# This matches the implementation in Denoising Diffusion Probabilistic Models: +# From Fairseq. +# Build sinusoidal embeddings. +# This matches the implementation in tensor2tensor, but differs slightly +# from the description in Section 3.5 of "Attention Is All You Need". +# """ +# assert len(timesteps.shape) == 1 +# +# half_dim = embedding_dim // 2 +# emb = math.log(10000) / (half_dim - 1) +# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) +# emb = emb.to(device=timesteps.device) +# emb = timesteps.float()[:, None] * emb[None, :] +# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) +# if embedding_dim % 2 == 1: # zero pad +# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) +# return emb def nonlinearity(x): diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 648ff9c34a..0e04537766 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .embeddings import get_timestep_embedding def convert_module_to_f16(l): @@ -86,25 +87,25 @@ def normalization(channels, swish=0.0): return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) -def timestep_embedding(timesteps, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. 
- """ - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding +# def timestep_embedding(timesteps, dim, max_period=10000): +# """ +# Create sinusoidal timestep embeddings. +# +# :param timesteps: a 1-D Tensor of N indices, one per batch element. +# These may be fractional. +# :param dim: the dimension of the output. +# :param max_period: controls the minimum frequency of the embeddings. +# :return: an [N x dim] Tensor of positional embeddings. +# """ +# half = dim // 2 +# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( +# device=timesteps.device +# ) +# args = timesteps[:, None].float() * freqs[None] +# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) +# if dim % 2: +# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) +# return embedding def zero_module(module): @@ -627,7 +628,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): """ hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) h = x.type(self.dtype) for module in self.input_blocks: @@ -714,7 +715,7 @@ class GlideTextToImageUNetModel(GlideUNetModel): def forward(self, x, timesteps, transformer_out=None): hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) # project the last token transformer_proj = self.transformer_proj(transformer_out[:, -1]) @@ -806,7 +807,7 @@ class GlideSuperResUNetModel(GlideUNetModel): x = torch.cat([x, upsampled], dim=1) hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) h = x for module in self.input_blocks: diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index cca3231341..cfc200bf6a 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -16,6 +16,7 @@ except: from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .embeddings import get_timestep_embedding def exists(val): @@ -316,34 +317,25 @@ def normalization(channels, swish=0.0): return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) -def timestep_embedding(timesteps, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. 
- """ - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module +#def timestep_embedding(timesteps, dim, max_period=10000): +# """ +# Create sinusoidal timestep embeddings. +# +# :param timesteps: a 1-D Tensor of N indices, one per batch element. +# These may be fractional. +# :param dim: the dimension of the output. +# :param max_period: controls the minimum frequency of the embeddings. +# :return: an [N x dim] Tensor of positional embeddings. +# """ +# half = dim // 2 +# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( +# device=timesteps.device +# ) +# args = timesteps[:, None].float() * freqs[None] +# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) +# if dim % 2: +# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) +# return embedding ## go @@ -1026,7 +1018,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): hs = [] if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device) - t_emb = timestep_embedding(timesteps, self.model_channels) + t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) emb = self.time_embed(t_emb) if self.num_classes is not None: @@ -1240,7 +1232,7 @@ class EncoderUNetModel(nn.Module): :param timesteps: a 1-D batch of timesteps. :return: an [N x K] Tensor of outputs. """ - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) results = [] h = x.type(self.dtype) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 299f96c9cd..7d00eb2174 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -26,6 +26,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .embeddings import get_timestep_embedding def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): @@ -381,21 +382,21 @@ def get_act(nonlinearity): raise NotImplementedError("activation function does not exist!") -def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): - assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 - half_dim = embedding_dim // 2 +#def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): +# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 +# half_dim = embedding_dim // 2 # magic number 10000 is from transformers - emb = math.log(max_positions) / (half_dim - 1) +# emb = math.log(max_positions) / (half_dim - 1) # emb = math.log(2.) 
/ (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) +# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = F.pad(emb, (0, 1), mode="constant") - assert emb.shape == (timesteps.shape[0], embedding_dim) - return emb +# emb = timesteps.float()[:, None] * emb[None, :] +# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) +# if embedding_dim % 2 == 1: # zero pad +# emb = F.pad(emb, (0, 1), mode="constant") +# assert emb.shape == (timesteps.shape[0], embedding_dim) +# return emb def default_init(scale=1.0): diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index db4ed6eb02..0b50e7bc86 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -21,718 +21,24 @@ import unittest import numpy as np import torch -from diffusers import ( - BDDMPipeline, - DDIMPipeline, - DDIMScheduler, - DDPMPipeline, - DDPMScheduler, - GlidePipeline, - GlideSuperResUNetModel, - GlideTextToImageUNetModel, - GradTTSPipeline, - GradTTSScheduler, - LatentDiffusionPipeline, - PNDMPipeline, - PNDMScheduler, - UNetGradTTSModel, - UNetLDMModel, - UNetModel, -) -from diffusers.configuration_utils import ConfigMixin -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.pipeline_bddm import DiffWave +#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding +from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding from diffusers.testing_utils import floats_tensor, slow, torch_device torch.backends.cuda.matmul.allow_tf32 = False -class ConfigTester(unittest.TestCase): - def test_load_not_from_mixin(self): - with self.assertRaises(ValueError): - ConfigMixin.from_config("dummy_path") +class EmbeddingsTests(unittest.TestCase): - def test_save_load(self): - class SampleObject(ConfigMixin): - config_name = "config.json" + def test_timestep_embeddings(self): + embedding_dim = 16 + timesteps = torch.arange(10) - def __init__( - self, - a=2, - b=5, - c=(2, 5), - d="for diffusion", - e=[1, 3], - ): - self.register_to_config(a=a, b=b, c=c, d=d, e=e) + t1 = get_timestep_embedding(timesteps, embedding_dim) + t2 = timestep_embedding(timesteps, embedding_dim) + t3 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_factor=8) - obj = SampleObject() - config = obj.config + import ipdb; ipdb.set_trace() - assert config["a"] == 2 - assert config["b"] == 5 - assert config["c"] == (2, 5) - assert config["d"] == "for diffusion" - assert config["e"] == [1, 3] - with tempfile.TemporaryDirectory() as tmpdirname: - obj.save_config(tmpdirname) - new_obj = SampleObject.from_config(tmpdirname) - new_config = new_obj.config - - # unfreeze configs - config = dict(config) - new_config = dict(new_config) - - assert config.pop("c") == (2, 5) # instantiated as tuple - assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json - assert config == new_config - - -class ModelTesterMixin: - def test_from_pretrained_save_pretrained(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - 
model.to(torch_device) - model.eval() - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) - - with torch.no_grad(): - image = model(**inputs_dict) - new_image = new_model(**inputs_dict) - - max_diff = (image - new_image).abs().sum().item() - self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes") - - def test_determinism(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - with torch.no_grad(): - first = model(**inputs_dict) - second = model(**inputs_dict) - - out_1 = first.cpu().numpy() - out_2 = second.cpu().numpy() - out_1 = out_1[~np.isnan(out_1)] - out_2 = out_2[~np.isnan(out_2)] - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) - - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - self.assertIsNotNone(output) - expected_shape = inputs_dict["x"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_forward_signature(self): - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - signature = inspect.signature(model.forward) - # signature.parameters is an OrderedDict => so arg_names order is deterministic - arg_names = [*signature.parameters.keys()] - - expected_arg_names = ["x", "timesteps"] - self.assertListEqual(arg_names[:2], expected_arg_names) - - def test_model_from_config(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - # test if the model can be loaded from the config - # and has all the expected shape - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_config(tmpdirname) - new_model = self.model_class.from_config(tmpdirname) - new_model.to(torch_device) - new_model.eval() - - # check if all paramters shape are the same - for param_name in model.state_dict().keys(): - param_1 = model.state_dict()[param_name] - param_2 = new_model.state_dict()[param_name] - self.assertEqual(param_1.shape, param_2.shape) - - with torch.no_grad(): - output_1 = model(**inputs_dict) - output_2 = new_model(**inputs_dict) - - self.assertEqual(output_1.shape, output_2.shape) - - def test_training(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.train() - output = model(**inputs_dict) - noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device) - loss = torch.nn.functional.mse_loss(output, noise) - loss.backward() - - -class UnetModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNetModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - - return {"x": noise, "timesteps": time_step} - - @property - def get_input_shape(self): - return (3, 32, 32) - - @property - def get_output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - 
"ch": 32, - "ch_mult": (1, 2), - "num_res_blocks": 2, - "attn_resolutions": (16,), - "resolution": 32, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = UNetModel.from_pretrained("fusing/ddpm_dummy") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) - time_step = torch.tensor([10]) - - with torch.no_grad(): - output = model(noise, time_step) - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase): - model_class = GlideSuperResUNetModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 6 - sizes = (32, 32) - low_res_size = (4, 4) - - noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device) - low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device) - time_step = torch.tensor([10] * noise.shape[0], device=torch_device) - - return {"x": noise, "timesteps": time_step, "low_res": low_res} - - @property - def get_input_shape(self): - return (3, 32, 32) - - @property - def get_output_shape(self): - return (6, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "attention_resolutions": (2,), - "channel_mult": (1, 2), - "in_channels": 6, - "out_channels": 6, - "model_channels": 32, - "num_head_channels": 8, - "num_heads_upsample": 1, - "num_res_blocks": 2, - "resblock_updown": True, - "resolution": 32, - "use_scale_shift_norm": True, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - output, _ = torch.split(output, 3, dim=1) - - self.assertIsNotNone(output) - expected_shape = inputs_dict["x"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_from_pretrained_hub(self): - model, loading_info = GlideSuperResUNetModel.from_pretrained( - "fusing/glide-super-res-dummy", output_loading_info=True - ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = GlideSuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy") - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, 3, 64, 64) - low_res = torch.randn(1, 3, 4, 4) - time_step = torch.tensor([42] * noise.shape[0]) - - with torch.no_grad(): - output = model(noise, time_step, low_res) - - output, _ = torch.split(output, 3, dim=1) - output_slice = 
output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-22.8782, -23.2652, -15.3966, -22.8034, -23.3159, -15.5640, -15.3970, -15.4614, - 10.4370]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): - model_class = GlideTextToImageUNetModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - transformer_dim = 32 - seq_len = 16 - - noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device) - emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device) - time_step = torch.tensor([10] * noise.shape[0], device=torch_device) - - return {"x": noise, "timesteps": time_step, "transformer_out": emb} - - @property - def get_input_shape(self): - return (3, 32, 32) - - @property - def get_output_shape(self): - return (6, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "attention_resolutions": (2,), - "channel_mult": (1, 2), - "in_channels": 3, - "out_channels": 6, - "model_channels": 32, - "num_head_channels": 8, - "num_heads_upsample": 1, - "num_res_blocks": 2, - "resblock_updown": True, - "resolution": 32, - "use_scale_shift_norm": True, - "transformer_dim": 32, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - output, _ = torch.split(output, 3, dim=1) - - self.assertIsNotNone(output) - expected_shape = inputs_dict["x"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_from_pretrained_hub(self): - model, loading_info = GlideTextToImageUNetModel.from_pretrained( - "fusing/unet-glide-text2im-dummy", output_loading_info=True - ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = GlideTextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy") - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn((1, model.config.in_channels, model.config.resolution, model.config.resolution)).to( - torch_device - ) - emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device) - time_step = torch.tensor([10] * noise.shape[0], device=torch_device) - - with torch.no_grad(): - output = model(noise, time_step, emb) - - output, _ = torch.split(output, 3, dim=1) - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNetLDMModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - - return {"x": noise, "timesteps": time_step} - - @property - def get_input_shape(self): - return (4, 32, 32) - - 
@property - def get_output_shape(self): - return (4, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "image_size": 32, - "in_channels": 4, - "out_channels": 4, - "model_channels": 32, - "num_res_blocks": 2, - "attention_resolutions": (16,), - "channel_mult": (1, 2), - "num_heads": 2, - "conv_resample": True, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) - time_step = torch.tensor([10] * noise.shape[0]) - - with torch.no_grad(): - output = model(noise, time_step) - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNetGradTTSModel - - @property - def dummy_input(self): - batch_size = 4 - num_features = 32 - seq_len = 16 - - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - condition = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - - return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask} - - @property - def get_input_shape(self): - return (4, 32, 16) - - @property - def get_output_shape(self): - return (4, 32, 16) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "dim": 64, - "groups": 4, - "dim_mults": (1, 2), - "n_feats": 32, - "pe_scale": 1000, - "n_spks": 1, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - num_features = model.config.n_feats - seq_len = 16 - noise = torch.randn((1, num_features, seq_len)) - condition = torch.randn((1, num_features, seq_len)) - mask = torch.randn((1, 1, seq_len)) - time_step = torch.tensor([10]) - - with torch.no_grad(): - output = model(noise, time_step, condition, mask) - - output_slice = output[0, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, 
expected_output_slice, atol=1e-3)) - - -class PipelineTesterMixin(unittest.TestCase): - def test_from_pretrained_save_pretrained(self): - # 1. Load models - model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) - schedular = DDPMScheduler(timesteps=10) - - ddpm = DDPMPipeline(model, schedular) - - with tempfile.TemporaryDirectory() as tmpdirname: - ddpm.save_pretrained(tmpdirname) - new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) - - generator = torch.manual_seed(0) - - image = ddpm(generator=generator) - generator = generator.manual_seed(0) - new_image = new_ddpm(generator=generator) - - assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" - - @slow - def test_from_pretrained_hub(self): - model_path = "fusing/ddpm-cifar10" - - ddpm = DDPMPipeline.from_pretrained(model_path) - ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) - - ddpm.noise_scheduler.num_timesteps = 10 - ddpm_from_hub.noise_scheduler.num_timesteps = 10 - - generator = torch.manual_seed(0) - - image = ddpm(generator=generator) - generator = generator.manual_seed(0) - new_image = ddpm_from_hub(generator=generator) - - assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" - - @slow - def test_ddpm_cifar10(self): - generator = torch.manual_seed(0) - model_id = "fusing/ddpm-cifar10" - - unet = UNetModel.from_pretrained(model_id) - noise_scheduler = DDPMScheduler.from_config(model_id) - noise_scheduler = noise_scheduler.set_format("pt") - - ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler) - image = ddpm(generator=generator) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_ddim_cifar10(self): - generator = torch.manual_seed(0) - model_id = "fusing/ddpm-cifar10" - - unet = UNetModel.from_pretrained(model_id) - noise_scheduler = DDIMScheduler(tensor_format="pt") - - ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) - image = ddim(generator=generator, eta=0.0) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_pndm_cifar10(self): - generator = torch.manual_seed(0) - model_id = "fusing/ddpm-cifar10" - - unet = UNetModel.from_pretrained(model_id) - noise_scheduler = PNDMScheduler(tensor_format="pt") - - pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler) - image = pndm(generator=generator) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_ldm_text2img(self): - model_id = "fusing/latent-diffusion-text2im-large" - ldm = LatentDiffusionPipeline.from_pretrained(model_id) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) - image = ldm([prompt], generator=generator, num_inference_steps=20) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 256, 256) - expected_slice = 
torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_glide_text2img(self): - model_id = "fusing/glide-base" - glide = GlidePipeline.from_pretrained(model_id) - - prompt = "a pencil sketch of a corgi" - generator = torch.manual_seed(0) - image = glide(prompt, generator=generator, num_inference_steps_upscale=20) - - image_slice = image[0, :3, :3, -1].cpu() - - assert image.shape == (1, 256, 256, 3) - expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_grad_tts(self): - model_id = "fusing/grad-tts-libri-tts" - grad_tts = GradTTSPipeline.from_pretrained(model_id) - noise_scheduler = GradTTSScheduler() - grad_tts.noise_scheduler = noise_scheduler - - text = "Hello world, I missed you so much." - generator = torch.manual_seed(0) - - # generate mel spectograms using text - mel_spec = grad_tts(text, generator=generator) - - assert mel_spec.shape == (1, 80, 143) - expected_slice = torch.tensor( - [-6.7584, -6.8347, -6.3293, -6.6437, -6.7233, -6.4684, -6.1187, -6.3172, -6.6890] - ) - assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2 - - def test_module_from_pipeline(self): - model = DiffWave(num_res_layers=4) - noise_scheduler = DDPMScheduler(timesteps=12) - - bddm = BDDMPipeline(model, noise_scheduler) - - # check if the library name for the diffwave moduel is set to pipeline module - self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm") - - # check if we can save and load the pipeline - with tempfile.TemporaryDirectory() as tmpdirname: - bddm.save_pretrained(tmpdirname) - _ = BDDMPipeline.from_pretrained(tmpdirname) - # check if the same works using the DifusionPipeline class - _ = DiffusionPipeline.from_pretrained(tmpdirname) From c7a39d38adb36059b31318b6841dda4e8f6ab172 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 11:37:37 +0000 Subject: [PATCH 34/35] refactor all sinus embeddings --- src/diffusers/models/embeddings.py | 131 +++++------------- src/diffusers/models/unet.py | 21 --- src/diffusers/models/unet_glide.py | 33 ++--- src/diffusers/models/unet_grad_tts.py | 23 +-- src/diffusers/models/unet_ldm.py | 25 +--- .../models/unet_sde_score_estimation.py | 17 --- tests/test_layers_utils.py | 83 ++++++++++- 7 files changed, 123 insertions(+), 210 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a9143053a6..2c26340f46 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -11,15 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch import math + import numpy as np - +import torch from torch import nn -import torch.nn.functional as F -def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000): +def get_timestep_embedding( + timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000 +): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -31,18 +32,22 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down :param max_period: controls the minimum frequency of the embeddings. 
:return: an [N x dim] Tensor of positional embeddings. """ - assert len(timesteps.shape) == 1 + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - emb = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (embedding_dim // 2 - downscale_freq_shift)) - emb = emb.to(device=timesteps.device) + emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift) + emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + emb = torch.exp(emb * emb_coeff) emb = timesteps[:, None].float() * emb[None, :] - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + # scale embeddings + emb = scale * emb - # flip sine and cosine embeddings + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings if flip_sin_to_cos: emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) @@ -52,96 +57,6 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down return emb -#def get_timestep_embedding(timesteps, embedding_dim): -# """ -# This matches the implementation in Denoising Diffusion Probabilistic Models: -# From Fairseq. -# Build sinusoidal embeddings. -# This matches the implementation in tensor2tensor, but differs slightly -# from the description in Section 3.5 of "Attention Is All You Need". -# """ -# assert len(timesteps.shape) == 1 -# -# half_dim = embedding_dim // 2 -# emb = math.log(10000) / (half_dim - 1) -# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) -# emb = emb.to(device=timesteps.device) -# emb = timesteps.float()[:, None] * emb[None, :] -# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) -# if embedding_dim % 2 == 1: # zero pad -# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - - -#def timestep_embedding(timesteps, dim, max_period=10000): -# """ -# Create sinusoidal timestep embeddings. -# -# :param timesteps: a 1-D Tensor of N indices, one per batch element. -# These may be fractional. -# :param dim: the dimension of the output. -# :param max_period: controls the minimum frequency of the embeddings. -# :return: an [N x dim] Tensor of positional embeddings. -# """ -# half = dim // 2 -# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( -# device=timesteps.device -# ) -# args = timesteps[:, None].float() * freqs[None, :] -# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) -# if dim % 2: -# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) -# return embedding - - -#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): -# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 -# half_dim = embedding_dim // 2 - # magic number 10000 is from transformers -# emb = math.log(max_positions) / (half_dim - 1) - # emb = math.log(2.) 
/ (half_dim - 1) -# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) - # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] - # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] -# emb = timesteps.float()[:, None] * emb[None, :] -# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) -# if embedding_dim % 2 == 1: # zero pad -# emb = F.pad(emb, (0, 1), mode="constant") -# assert emb.shape == (timesteps.shape[0], embedding_dim) -# return emb - - -# unet_grad_tts.py -class SinusoidalPosEmb(torch.nn.Module): - def __init__(self, dim): - super(SinusoidalPosEmb, self).__init__() - self.dim = dim - - def forward(self, x, scale=1000): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) - emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - -# unet_rl.py -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - # unet_sde_score_estimation.py class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" @@ -153,3 +68,19 @@ class GaussianFourierProjection(nn.Module): def forward(self, x): x_proj = x[:, None] * self.W[None, :] * 2 * np.pi return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + +# unet_rl.py - TODO(need test) +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 7d5eebfd3d..1749def9b1 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -33,27 +33,6 @@ from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding -#def get_timestep_embedding(timesteps, embedding_dim): -# """ -# This matches the implementation in Denoising Diffusion Probabilistic Models: -# From Fairseq. -# Build sinusoidal embeddings. -# This matches the implementation in tensor2tensor, but differs slightly -# from the description in Section 3.5 of "Attention Is All You Need". 
-# """ -# assert len(timesteps.shape) == 1 -# -# half_dim = embedding_dim // 2 -# emb = math.log(10000) / (half_dim - 1) -# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) -# emb = emb.to(device=timesteps.device) -# emb = timesteps.float()[:, None] * emb[None, :] -# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) -# if embedding_dim % 2 == 1: # zero pad -# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) -# return emb - - def nonlinearity(x): # swish return x * torch.sigmoid(x) diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 0e04537766..c154db9210 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -87,27 +87,6 @@ def normalization(channels, swish=0.0): return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) -# def timestep_embedding(timesteps, dim, max_period=10000): -# """ -# Create sinusoidal timestep embeddings. -# -# :param timesteps: a 1-D Tensor of N indices, one per batch element. -# These may be fractional. -# :param dim: the dimension of the output. -# :param max_period: controls the minimum frequency of the embeddings. -# :return: an [N x dim] Tensor of positional embeddings. -# """ -# half = dim // 2 -# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( -# device=timesteps.device -# ) -# args = timesteps[:, None].float() * freqs[None] -# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) -# if dim % 2: -# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) -# return embedding - - def zero_module(module): """ Zero out the parameters of a module and return it. @@ -628,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin): """ hs = [] - emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) h = x.type(self.dtype) for module in self.input_blocks: @@ -715,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel): def forward(self, x, timesteps, transformer_out=None): hs = [] - emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) # project the last token transformer_proj = self.transformer_proj(transformer_out[:, -1]) @@ -807,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel): x = torch.cat([x, upsampled], dim=1) hs = [] - emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) h = x for module in self.input_blocks: diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 81719f088b..ccae3133fd 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -1,5 +1,3 @@ -import math - import torch @@ -11,6 +9,7 @@ except: from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .embeddings import get_timestep_embedding class Mish(torch.nn.Module): @@ -107,21 +106,6 @@ class Residual(torch.nn.Module): return output -class 
SinusoidalPosEmb(torch.nn.Module): - def __init__(self, dim): - super(SinusoidalPosEmb, self).__init__() - self.dim = dim - - def forward(self, x, scale=1000): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) - emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - class UNetGradTTSModel(ModelMixin, ConfigMixin): def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): super(UNetGradTTSModel, self).__init__() @@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats) ) - self.time_pos_emb = SinusoidalPosEmb(dim) self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)) dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)] @@ -198,8 +181,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): if not isinstance(spk, type(None)): s = self.spk_mlp(spk) - - t = self.time_pos_emb(timesteps, scale=self.pe_scale) + t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale) + t = self.mlp(t) if self.n_spks < 2: diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index cfc200bf6a..da84391a36 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -317,27 +317,6 @@ def normalization(channels, swish=0.0): return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) -#def timestep_embedding(timesteps, dim, max_period=10000): -# """ -# Create sinusoidal timestep embeddings. -# -# :param timesteps: a 1-D Tensor of N indices, one per batch element. -# These may be fractional. -# :param dim: the dimension of the output. -# :param max_period: controls the minimum frequency of the embeddings. -# :return: an [N x dim] Tensor of positional embeddings. -# """ -# half = dim // 2 -# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( -# device=timesteps.device -# ) -# args = timesteps[:, None].float() * freqs[None] -# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) -# if dim % 2: -# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) -# return embedding - - ## go class AttentionPool2d(nn.Module): """ @@ -1232,7 +1211,9 @@ class EncoderUNetModel(nn.Module): :param timesteps: a 1-D batch of timesteps. :return: an [N x K] Tensor of outputs. 
""" - emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) results = [] h = x.type(self.dtype) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 7d00eb2174..0f0cc4b7e0 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -382,23 +382,6 @@ def get_act(nonlinearity): raise NotImplementedError("activation function does not exist!") -#def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): -# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 -# half_dim = embedding_dim // 2 - # magic number 10000 is from transformers -# emb = math.log(max_positions) / (half_dim - 1) - # emb = math.log(2.) / (half_dim - 1) -# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) - # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] - # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] -# emb = timesteps.float()[:, None] * emb[None, :] -# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) -# if embedding_dim % 2 == 1: # zero pad -# emb = F.pad(emb, (0, 1), mode="constant") -# assert emb.shape == (timesteps.shape[0], embedding_dim) -# return emb - - def default_init(scale=1.0): """The same initialization used in DDPM.""" scale = 1e-10 if scale == 0 else scale diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index 0b50e7bc86..42a4261081 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -21,8 +21,7 @@ import unittest import numpy as np import torch -#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding -from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding +from diffusers.models.embeddings import get_timestep_embedding from diffusers.testing_utils import floats_tensor, slow, torch_device @@ -30,15 +29,87 @@ torch.backends.cuda.matmul.allow_tf32 = False class EmbeddingsTests(unittest.TestCase): - def test_timestep_embeddings(self): + embedding_dim = 256 + timesteps = torch.arange(16) + + t1 = get_timestep_embedding(timesteps, embedding_dim) + + # first vector should always be composed only of 0's and 1's + assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5 + assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5 + + # last element of each vector should be one + assert (t1[:, -1] - 1).abs().sum() < 1e-5 + + # For large embeddings (e.g. 
128) the frequency of each successive vector is higher
+        # than that of the previous one, which means that the gradient of each later
+        # vector is ALWAYS higher than the previous one's
+        grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1)
+
+        prev_grad = 0.0
+        for grad in grad_mean:
+            assert grad > prev_grad
+            prev_grad = grad
+
+    def test_timestep_defaults(self):
         embedding_dim = 16
         timesteps = torch.arange(10)
 
         t1 = get_timestep_embedding(timesteps, embedding_dim)
-        t2 = timestep_embedding(timesteps, embedding_dim)
-        t3 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_factor=8)
+        t2 = get_timestep_embedding(
+            timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
+        )
 
-        import ipdb; ipdb.set_trace()
+        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
 
+    def test_timestep_flip_sin_cos(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
 
+        t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True)
+        t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1)
+
+        t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False)
+
+        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
+
+    def test_timestep_downscale_freq_shift(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0)
+        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1)
+
+        # get cosine half (vectors that are wrapped into cosine)
+        cosine_half = (t1 - t2)[:, embedding_dim // 2 :]
+
+        # cosine needs to be negative
+        assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5
+
+    def test_sinusoid_embeddings_hardcoded(self):
+        embedding_dim = 64
+        timesteps = torch.arange(128)
+
+        # standard unet, score_sde
+        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False)
+        # glide, ldm
+        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True)
+        # grad-tts
+        t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000)
+
+        assert torch.allclose(
+            t1[23:26, 47:50].flatten().cpu(),
+            torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]),
+            1e-3,
+        )
+        assert torch.allclose(
+            t2[23:26, 47:50].flatten().cpu(),
+            torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]),
+            1e-3,
+        )
+        assert torch.allclose(
+            t3[23:26, 47:50].flatten().cpu(),
+            torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
+            1e-3,
+        )

From 0027993e91a1caaa990b4569602d28d6dfdbd180 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Mon, 27 Jun 2022 14:48:20 +0200
Subject: [PATCH 35/35] add upsample and downsample blocks

---
 src/diffusers/models/resnet.py | 278 +++++++++++++++++++++++++++++++++
 1 file changed, 278 insertions(+)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index e69de29bb2..04e3735d60 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -0,0 +1,278 @@
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
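+    Remaining positional and keyword arguments are forwarded to the matching ``nn.AvgPool1d/2d/3d`` constructor.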
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def conv_transpose_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D transposed convolution module.
+    """
+    if dims == 1:
+        return nn.ConvTranspose1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.ConvTranspose2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.ConvTranspose3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+def nonlinearity(x, swish=1.0):
+    # swish
+    if swish == 1.0:
+        return F.silu(x)
+    else:
+        return x * torch.sigmoid(x * float(swish))
+
+
+class Upsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+        upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        self.use_conv_transpose = use_conv_transpose
+
+        if use_conv_transpose:
+            self.conv = conv_transpose_nd(dims, self.channels, self.out_channels, 4, 2, 1)
+        elif use_conv:
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.use_conv_transpose:
+            return self.conv(x)
+
+        if self.dims == 3:
+            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
+        else:
+            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+        if self.use_conv:
+            x = self.conv(x)
+
+        return x
+
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+        downsampling occurs in the inner-two dimensions.
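+    :param out_channels: if given, the number of output channels; defaults to ``channels``.
+    :param padding: padding used by the downsampling convolution (checked in ``forward`` when it is 0).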
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.padding = padding + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv and self.padding == 0 and self.dims == 2: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + return self.down(x) + + +class UNetUpsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +class GlideUpsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class LDMUpsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
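+    :param out_channels: if given, the number of output channels; defaults to ``channels``.
+    :param padding: padding passed to the optional convolution.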
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+
+class GradTTSUpsample(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Upsample1d(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+# class ResnetBlock(nn.Module):
+#     def __init__(
+#         self,
+#         *,
+#         in_channels,
+#         out_channels=None,
+#         conv_shortcut=False,
+#         dropout,
+#         temb_channels=512,
+#         use_scale_shift_norm=False,
+#     ):
+#         super().__init__()
+#         self.in_channels = in_channels
+#         out_channels = in_channels if out_channels is None else out_channels
+#         self.out_channels = out_channels
+#         self.use_conv_shortcut = conv_shortcut
+#         self.use_scale_shift_norm = use_scale_shift_norm

+#         self.norm1 = Normalize(in_channels)
+#         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

+#         temp_out_channels = 2 * out_channels if use_scale_shift_norm else out_channels
+#         self.temb_proj = torch.nn.Linear(temb_channels, temp_out_channels)

+#         self.norm2 = Normalize(out_channels)
+#         self.dropout = torch.nn.Dropout(dropout)
+#         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+#         if self.in_channels != self.out_channels:
+#             if self.use_conv_shortcut:
+#                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+#             else:
+#                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

+#     def forward(self, x, temb):
+#         h = x
+#         h = self.norm1(h)
+#         h = nonlinearity(h)
+#         h = self.conv1(h)

+#         # TODO: check if this broadcasting works correctly for 1D and 3D
+#         temb = self.temb_proj(nonlinearity(temb))[:, :, None, None]

+#         if self.use_scale_shift_norm:
+#             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+#             scale, shift = torch.chunk(temb, 2, dim=1)
+#             h = self.norm2(h) * (1 + scale) + shift
+#             h = out_rest(h)
+#         else:
+#             h = h + temb
+#             h = self.norm2(h)
+#             h = nonlinearity(h)
+#             h = self.dropout(h)
+#             h = self.conv2(h)

+#         if self.in_channels != self.out_channels:
+#             if self.use_conv_shortcut:
+#                 x = self.conv_shortcut(x)
+#             else:
+#                 x = self.nin_shortcut(x)

+#         return x + h
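
The three sinusoidal-embedding variants deleted in PATCH 34 (the DDPM UNet / score-SDE version, the GLIDE/LDM `timestep_embedding`, and Grad-TTS's `SinusoidalPosEmb`) all reduce to calls into the unified `get_timestep_embedding` helper; `test_sinusoid_embeddings_hardcoded` in PATCH 34 pins down exactly these three configurations. A minimal sketch of the mapping, assuming the signature added to `src/diffusers/models/embeddings.py` in PATCH 34:

```python
import torch

from diffusers.models.embeddings import get_timestep_embedding

timesteps = torch.arange(128)
embedding_dim = 64

# DDPM UNet / score-SDE: sine half first, log-frequency ramp divided by (half_dim - 1)
t_unet = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1)

# GLIDE / LDM: cosine half first, log-frequency ramp divided by half_dim
t_glide = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0)

# Grad-TTS: default layout, but the sin/cos arguments are multiplied by pe_scale (1000)
t_grad_tts = get_timestep_embedding(timesteps, embedding_dim, scale=1000)

assert t_unet.shape == t_glide.shape == t_grad_tts.shape == (128, embedding_dim)
```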
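The `Upsample` and `Downsample` blocks added in PATCH 35 double or halve the spatial resolution, optionally applying a convolution. A quick shape check, sketched under the assumption that the classes are importable from `diffusers.models.resnet` once the patch is applied:

```python
import torch

from diffusers.models.resnet import Downsample, Upsample

x = torch.randn(1, 32, 64, 64)

# strided 3x3 convolution: 64x64 -> 32x32
down = Downsample(channels=32, use_conv=True)
# nearest-neighbor interpolation followed by a 3x3 convolution: 64x64 -> 128x128
up = Upsample(channels=32, use_conv=True)
# learned upsampling via ConvTranspose2d(kernel=4, stride=2, padding=1): 64x64 -> 128x128
up_t = Upsample(channels=32, use_conv=False, use_conv_transpose=True, out_channels=32)

assert down(x).shape == (1, 32, 32, 32)
assert up(x).shape == (1, 32, 128, 128)
assert up_t(x).shape == (1, 32, 128, 128)
```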