diff --git a/models/vision/ddpm/example.py b/models/vision/ddpm/example.py index 2ba753c385..0dca57aa38 100755 --- a/models/vision/ddpm/example.py +++ b/models/vision/ddpm/example.py @@ -1,20 +1,26 @@ #!/usr/bin/env python3 import tempfile import sys - +import os +import pathlib from modeling_ddpm import DDPM - -model_id = sys.argv[1] - -ddpm = DDPM.from_pretrained(model_id) -image = ddpm() - import PIL.Image import numpy as np -image_processed = image.cpu().permute(0, 2, 3, 1) -image_processed = (image_processed + 1.0) * 127.5 -image_processed = image_processed.numpy().astype(np.uint8) -image_pil = PIL.Image.fromarray(image_processed[0]) -image_pil.save("test.png") -import ipdb; ipdb.set_trace() +model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-lsun-cifar10", "ddpm-lsun-celeba-hq", "ddpm-lsun-celeba-hq-ema"] + +for model_id in model_ids: + + path = os.path.join("/home/patrick/images/hf", model_id) + pathlib.Path(path).mkdir(parents=True, exist_ok=True) + + ddpm = DDPM.from_pretrained("fusing/" + model_id) + image = ddpm(batch_size=4) + + image_processed = image.cpu().permute(0, 2, 3, 1) + image_processed = (image_processed + 1.0) * 127.5 + image_processed = image_processed.numpy().astype(np.uint8) + + for i in range(image_processed.shape[0]): + image_pil = PIL.Image.fromarray(image_processed[i]) + image_pil.save(os.path.join(path, f"image_{i}.png")) diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index 24b9fdecbf..a10feaba40 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -33,7 +33,7 @@ class DDPM(DiffusionPipeline): self.unet.to(torch_device) # 1. Sample gaussian noise - image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) + image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): # i) define coefficients for time step t clip_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t)) diff --git a/src/diffusers/schedulers/gaussian_ddpm.py b/src/diffusers/schedulers/gaussian_ddpm.py index 4fcdfdf2bd..2a25cbbfc9 100644 --- a/src/diffusers/schedulers/gaussian_ddpm.py +++ b/src/diffusers/schedulers/gaussian_ddpm.py @@ -108,7 +108,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin): def sample_variance(self, time_step, shape, device, generator=None): variance = self.log_variance[time_step] - nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1) + nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :] noise = self.sample_noise(shape, device=device, generator=generator) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 4655c96749..6dce91ae4b 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -76,7 +76,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): class ModelTesterMixin(unittest.TestCase): @property def dummy_input(self): - batch_size = 1 + batch_size = 4 num_channels = 3 sizes = (32, 32)