diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index cbf2252c18..0fa6852bd1 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -19,6 +19,7 @@ import os
 from typing import Optional, Union
 
 from huggingface_hub import snapshot_download
+from PIL import Image
 
 from .configuration_utils import ConfigMixin
 from .utils import DIFFUSERS_CACHE, logging
@@ -189,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
         # 5. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)
         return model
+
+    @staticmethod
+    def numpy_to_pil(images):
+        """
+        Convert a numpy image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        pil_images = [Image.fromarray(image) for image in images]
+
+        return pil_images
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index a49a60dc84..aba1cde2ec 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="numpy"):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -56,5 +56,7 @@ class DDPMPipeline(DiffusionPipeline):
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
         return {"sample": image}
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index f41ab01e15..f9da17fdee 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -30,6 +30,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
         eta=0.0,
         guidance_scale=1.0,
         num_inference_steps=50,
+        output_type="numpy",
     ):
         # eta corresponds to η in paper and should be between [0, 1]
 
@@ -86,6 +87,8 @@ class LatentDiffusionPipeline(DiffusionPipeline):
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
         return {"sample": image}
 
diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
index 38b4eb0517..9ddfa68b90 100644
--- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
 
     @torch.no_grad()
     def __call__(
-        self,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        num_inference_steps=50,
+        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="numpy"
     ):
         # eta corresponds to η in paper and should be between [0, 1]
 
@@ -47,5 +42,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
         return {"sample": image}
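Aside (not part of the patch): the numpy_to_pil helper added to DiffusionPipeline above takes the NHWC float array in [0, 1] that the pipelines already return and converts it into a list of PIL.Image objects, adding a batch dimension first when a single H x W x C image is passed. A minimal standalone sketch of that conversion, run on a synthetic array instead of real pipeline output:

    import numpy as np
    from PIL import Image

    # stand-in for a pipeline's numpy output: one 32x32 RGB image with values in [0, 1]
    images = np.random.rand(32, 32, 3)

    if images.ndim == 3:  # single image -> add a batch dimension
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")  # [0, 1] floats -> uint8 pixel values
    pil_images = [Image.fromarray(image) for image in images]

    print(len(pil_images), pil_images[0].size)  # 1 (32, 32)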
diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py
index 69d6db6619..90cf761185 100644
--- a/src/diffusers/pipelines/pndm/pipeline_pndm.py
+++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="numpy"):
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
         if torch_device is None:
@@ -59,5 +59,7 @@ class PNDMPipeline(DiffusionPipeline):
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
         return {"sample": image}
diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
index 342847e2d1..bf1dbd38b9 100644
--- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -11,7 +11,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.register_modules(model=model, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, num_inference_steps=2000, generator=None):
+    def __call__(self, num_inference_steps=2000, generator=None, output_type="numpy"):
         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
         img_size = self.model.config.image_size
@@ -47,5 +47,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
 
         sample = sample.clamp(0, 1)
         sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            sample = self.numpy_to_pil(sample)
 
         return {"sample": sample}
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 43b086290e..130ea65554 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -18,11 +18,11 @@ import inspect
 import math
 import tempfile
 import unittest
-from atexit import register
 
 import numpy as np
 import torch
 
+import PIL
 from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
 from diffusers import (
     AutoencoderKL,
@@ -728,6 +728,26 @@ class PipelineTesterMixin(unittest.TestCase):
 
         assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
 
+    @slow
+    def test_output_format(self):
+        model_path = "google/ddpm-cifar10-32"
+
+        pipe = DDIMPipeline.from_pretrained(model_path)
+
+        generator = torch.manual_seed(0)
+        images = pipe(generator=generator)["sample"]
+        assert images.shape == (1, 32, 32, 3)
+        assert isinstance(images, np.ndarray)
+
+        images = pipe(generator=generator, output_type="numpy")["sample"]
+        assert images.shape == (1, 32, 32, 3)
+        assert isinstance(images, np.ndarray)
+
+        images = pipe(generator=generator, output_type="pil")["sample"]
+        assert isinstance(images, list)
+        assert len(images) == 1
+        assert isinstance(images[0], PIL.Image.Image)
+
     @slow
     def test_ddpm_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
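Usage note (not part of the patch): as the new test_output_format test suggests, callers choose the return format per call via output_type, with "numpy" remaining the default. A hedged sketch of that usage, reusing the model id and DDIMPipeline from the test; the DDIM pipeline's own diff is not shown in this excerpt, so its output_type support is assumed here, and the output filename is made up for illustration:

    import torch
    from diffusers import DDIMPipeline

    pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")

    generator = torch.manual_seed(0)
    np_images = pipe(generator=generator)["sample"]  # default "numpy": array of shape (1, 32, 32, 3)
    pil_images = pipe(generator=generator, output_type="pil")["sample"]  # list with one PIL.Image.Image
    pil_images[0].save("ddpm_cifar10_sample.png")  # hypothetical output path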