From 383dc795c9a4b8a7f2d75f76206be3dc70d47168 Mon Sep 17 00:00:00 2001 From: anton-l Date: Wed, 8 Jun 2022 13:51:46 +0200 Subject: [PATCH] glide is alive! --- models/vision/glide/modeling_glide.py | 3 +++ models/vision/glide/run_glide.py | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/models/vision/glide/modeling_glide.py b/models/vision/glide/modeling_glide.py index cc2880d85d..22dbcaac02 100644 --- a/models/vision/glide/modeling_glide.py +++ b/models/vision/glide/modeling_glide.py @@ -124,6 +124,7 @@ class GLIDE(DiffusionPipeline): - _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) + @torch.no_grad() def __call__(self, prompt, generator=None, torch_device=None): torch_device = "cuda" if torch.cuda.is_available() else "cpu" @@ -164,4 +165,6 @@ class GLIDE(DiffusionPipeline): nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0 image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise + image = image[0].permute(1, 2, 0) + return image diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py index 1bea36fc85..c7510166da 100644 --- a/models/vision/glide/run_glide.py +++ b/models/vision/glide/run_glide.py @@ -1,6 +1,9 @@ import torch from modeling_glide import GLIDE +import matplotlib +import matplotlib.pyplot as plt +matplotlib.rcParams['interactive'] = True generator = torch.Generator() @@ -10,5 +13,8 @@ generator = generator.manual_seed(0) pipeline = GLIDE.from_pretrained("fusing/glide-base") img = pipeline("an oil painting of a corgi", generator) +img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy() -print(img) +plt.figure(figsize=(8, 8)) +plt.imshow(img) +plt.show()