From ba264419f40b94fd2e8135096db4780e1c188aef Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 00:07:57 +0000 Subject: [PATCH] finish vp --- .../models/unet_sde_score_estimation.py | 5 ++--- .../pipelines/pipeline_score_sde_ve.py | 1 + .../pipelines/pipeline_score_sde_vp.py | 15 +++++---------- src/diffusers/schedulers/__init__.py | 2 +- src/diffusers/schedulers/scheduling_sde_ve.py | 2 ++ src/diffusers/schedulers/scheduling_sde_vp.py | 19 ++++++++++++++----- tests/test_modeling_utils.py | 4 ++-- 7 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 784d528dd4..299f96c9cd 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -766,7 +766,6 @@ class NCSNpp(ModelMixin, ConfigMixin): continuous=continuous, ) self.act = act = get_act(nonlinearity) - self.register_buffer('sigmas', torch.tensor(np.linspace(np.log(50), np.log(0.01), 10))) self.nf = nf self.num_res_blocks = num_res_blocks @@ -939,7 +938,7 @@ class NCSNpp(ModelMixin, ConfigMixin): self.all_modules = nn.ModuleList(modules) - def forward(self, x, time_cond): + def forward(self, x, time_cond, sigmas=None): # timestep/noise_level embedding; only for continuous training modules = self.all_modules m_idx = 0 @@ -952,7 +951,7 @@ class NCSNpp(ModelMixin, ConfigMixin): elif self.embedding_type == "positional": # Sinusoidal positional embeddings. timesteps = time_cond - used_sigmas = self.sigmas[time_cond.long()] + used_sigmas = sigmas temb = get_timestep_embedding(timesteps, self.nf) else: diff --git a/src/diffusers/pipelines/pipeline_score_sde_ve.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py index a1a4843af1..1dfd304d83 100644 --- a/src/diffusers/pipelines/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/pipeline_score_sde_ve.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import torch + from diffusers import DiffusionPipeline diff --git a/src/diffusers/pipelines/pipeline_score_sde_vp.py b/src/diffusers/pipelines/pipeline_score_sde_vp.py index 9eb886296b..29551d9a6e 100644 --- a/src/diffusers/pipelines/pipeline_score_sde_vp.py +++ b/src/diffusers/pipelines/pipeline_score_sde_vp.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import torch + from diffusers import DiffusionPipeline @@ -16,27 +17,21 @@ class ScoreSdeVpPipeline(DiffusionPipeline): channels = self.model.config.num_channels shape = (1, channels, img_size, img_size) - beta_min, beta_max = 0.1, 20 - model = self.model.to(device) x = torch.randn(*shape).to(device) self.scheduler.set_timesteps(num_inference_steps) - for i, t in enumerate(self.scheduler.timesteps): + for t in self.scheduler.timesteps: t = t * torch.ones(shape[0], device=device) - sigma_t = t * (num_inference_steps - 1) + scaled_t = t * (num_inference_steps - 1) with torch.no_grad(): - result = model(x, sigma_t) - - log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min - std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) - result = -result / std[:, None, None, None] + result = model(x, scaled_t) x, x_mean = self.scheduler.step_pred(result, x, t) - x_mean = (x_mean + 1.) / 2. + x_mean = (x_mean + 1.0) / 2.0 return x_mean diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 6a6d628661..ad66fe5991 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -20,6 +20,6 @@ from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler -from .scheduling_utils import SchedulerMixin from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler +from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 2456afad7d..79936105b9 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -52,6 +52,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ) def step_pred(self, result, x, t): + # TODO(Patrick) better comments + non-PyTorch t = t * torch.ones(x.shape[0], device=x.device) timestep = (t * (len(self.timesteps) - 1)).long() @@ -70,6 +71,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): return x, x_mean def step_correct(self, result, x): + # TODO(Patrick) better comments + non-PyTorch noise = torch.randn_like(x) grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index c7b6497117..dda32a2742 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -40,16 +40,25 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) def step_pred(self, result, x, t): - dt = -1. / len(self.timesteps) - z = torch.randn_like(x) + # TODO(Patrick) better comments + non-PyTorch + # postprocess model result + log_mean_coeff = ( + -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min + ) + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) + result = -result / std[:, None, None, None] - beta_t = self.beta_min + t * (self.beta_max - self.beta_min) + # compute + dt = -1.0 / len(self.timesteps) + + beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) drift = -0.5 * beta_t[:, None, None, None] * x diffusion = torch.sqrt(beta_t) - drift = drift - diffusion[:, None, None, None] ** 2 * result - x_mean = x + drift * dt + + # add noise + z = torch.randn_like(x) x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z return x, x_mean diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 32bc3003c5..6c5c115f19 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -746,8 +746,8 @@ class PipelineTesterMixin(unittest.TestCase): @slow def test_score_sde_vp_pipeline(self): - model = NCSNpp.from_pretrained("/home/patrick/cifar10-ddpmpp-vp") - scheduler = ScoreSdeVpScheduler() + model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp") + scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp") sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler)