From 3fe026e06c895049fe6f072fc2b394b2c8e85551 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 13 Jun 2022 12:44:45 +0200 Subject: [PATCH] Glide tensor format --- src/diffusers/pipelines/conversion_glide.py | 2 +- src/diffusers/pipelines/pipeline_glide.py | 15 +++++++++++---- src/diffusers/schedulers/glide_ddim.py | 0 3 files changed, 12 insertions(+), 5 deletions(-) delete mode 100644 src/diffusers/schedulers/glide_ddim.py diff --git a/src/diffusers/pipelines/conversion_glide.py b/src/diffusers/pipelines/conversion_glide.py index 499c071204..1ae69ac48a 100644 --- a/src/diffusers/pipelines/conversion_glide.py +++ b/src/diffusers/pipelines/conversion_glide.py @@ -97,7 +97,7 @@ superres_model = GLIDESuperResUNetModel( superres_model.load_state_dict(ups_state_dict, strict=False) -upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02) +upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt") glide = GLIDE( text_unet=text2im_model, diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index 9a45d492b9..6d2c3982fd 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -30,7 +30,6 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( ModelOutput, - add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, @@ -860,6 +859,9 @@ class GLIDE(DiffusionPipeline): nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0 image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise + image = image[:1].permute(0, 2, 3, 1) + return image + # 4. Run the upscaling step batch_size = 1 image = image[:1] @@ -872,10 +874,10 @@ class GLIDE(DiffusionPipeline): # Sample gaussian noise to begin loop image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), + (batch_size, self.upscale_unet.in_channels // 2, self.upscale_unet.resolution, self.upscale_unet.resolution), generator=generator, ) - image = image.to(torch_device) + image = image.to(torch_device) * upsample_temp # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding @@ -887,10 +889,15 @@ class GLIDE(DiffusionPipeline): # - eta -> η # - pred_image_direction -> "direction pointingc to x_t" # - pred_prev_image -> "x_t-1" + + num_trained_timesteps = self.upscale_noise_scheduler.timesteps + inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale) + self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale) + for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale): # 1. predict noise residual with torch.no_grad(): - time_input = torch.tensor([t] * image.shape[0], device=torch_device) + time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device) model_output = self.upscale_unet(image, time_input, low_res) noise_residual, pred_variance = torch.split(model_output, 3, dim=1) diff --git a/src/diffusers/schedulers/glide_ddim.py b/src/diffusers/schedulers/glide_ddim.py deleted file mode 100644 index e69de29bb2..0000000000