Mirror of https://github.com/huggingface/diffusers.git
PIL-ify the pipeline outputs (#111)
@@ -19,6 +19,7 @@ import os
 from typing import Optional, Union
 
 from huggingface_hub import snapshot_download
+from PIL import Image
 
 from .configuration_utils import ConfigMixin
 from .utils import DIFFUSERS_CACHE, logging
@@ -189,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
         # 5. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)
         return model
+
+    @staticmethod
+    def numpy_to_pil(images):
+        """
+        Convert a numpy image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        pil_images = [Image.fromarray(image) for image in images]
+
+        return pil_images
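(For reference, a standalone sketch of the conversion above; hypothetical demo code, not part of the commit. A single HxWx3 float array in [0, 1] comes back as a one-element list of PIL images.)

import numpy as np
from PIL import Image

def numpy_to_pil(images):
    # Accept one HxWxC image or a batch of NxHxWxC images with values in [0, 1].
    if images.ndim == 3:
        images = images[None, ...]
    # Scale to 0..255 and cast to uint8, the layout PIL's fromarray expects.
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]

pils = numpy_to_pil(np.random.rand(32, 32, 3))
assert len(pils) == 1 and pils[0].size == (32, 32)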
@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="numpy"):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -56,5 +56,7 @@ class DDPMPipeline(DiffusionPipeline):
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
         return {"sample": image}
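(Hedged usage sketch of the new flag; the checkpoint id is borrowed from the test added at the bottom of this commit, everything else assumes the API exactly as shown in the hunks above.)

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
generator = torch.manual_seed(0)

# Default output: a numpy array of shape (batch, height, width, channels) in [0, 1].
images = pipe(generator=generator)["sample"]

# output_type="pil": a list of PIL.Image.Image objects, ready to save or display.
pil_images = pipe(generator=generator, output_type="pil")["sample"]
pil_images[0].save("ddpm_sample.png")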
@@ -30,6 +30,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
         eta=0.0,
         guidance_scale=1.0,
         num_inference_steps=50,
+        output_type="numpy",
     ):
         # eta corresponds to η in paper and should be between [0, 1]
 
@@ -86,6 +87,8 @@ class LatentDiffusionPipeline(DiffusionPipeline):
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
         return {"sample": image}
 
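(The same flag on the text-conditional LDM pipeline; a sketch assuming the CompVis/ldm-text2im-large-256 checkpoint and that __call__ takes a prompt as its first positional argument, neither of which is shown in this diff.)

from diffusers import LatentDiffusionPipeline

ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
# output_type="pil" removes the need to convert the numpy output by hand.
images = ldm("a painting of a squirrel eating a burger", guidance_scale=6.0, output_type="pil")["sample"]
images[0].save("squirrel.png")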
@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
 
     @torch.no_grad()
     def __call__(
-        self,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        num_inference_steps=50,
+        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="numpy"
     ):
         # eta corresponds to η in paper and should be between [0, 1]
 
@@ -47,5 +42,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
         return {"sample": image}
@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="numpy"):
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
         if torch_device is None:
@@ -59,5 +59,7 @@ class PNDMPipeline(DiffusionPipeline):
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
         return {"sample": image}
@@ -11,7 +11,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.register_modules(model=model, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, num_inference_steps=2000, generator=None):
+    def __call__(self, num_inference_steps=2000, generator=None, output_type="numpy"):
         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
         img_size = self.model.config.image_size
@@ -47,5 +47,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
 
         sample = sample.clamp(0, 1)
         sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            sample = self.numpy_to_pil(sample)
 
         return {"sample": sample}
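(Note: unlike the pipelines above, ScoreSdeVePipeline clamps directly to [0, 1] without the (x / 2 + 0.5) shift, presumably because the VE SDE sampler does not produce samples in [-1, 1]; the PIL conversion itself is the same numpy_to_pil helper.)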
@@ -18,11 +18,11 @@ import inspect
 import math
 import tempfile
 import unittest
-from atexit import register
 
 import numpy as np
 import torch
 
+import PIL
 from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
 from diffusers import (
     AutoencoderKL,
@@ -728,6 +728,26 @@ class PipelineTesterMixin(unittest.TestCase):
 
         assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
 
+    @slow
+    def test_output_format(self):
+        model_path = "google/ddpm-cifar10-32"
+
+        pipe = DDIMPipeline.from_pretrained(model_path)
+
+        generator = torch.manual_seed(0)
+        images = pipe(generator=generator)["sample"]
+        assert images.shape == (1, 32, 32, 3)
+        assert isinstance(images, np.ndarray)
+
+        images = pipe(generator=generator, output_type="numpy")["sample"]
+        assert images.shape == (1, 32, 32, 3)
+        assert isinstance(images, np.ndarray)
+
+        images = pipe(generator=generator, output_type="pil")["sample"]
+        assert isinstance(images, list)
+        assert len(images) == 1
+        assert isinstance(images[0], PIL.Image.Image)
+
     @slow
     def test_ddpm_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
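(Since test_output_format is marked @slow it is skipped by default; in the diffusers test setup such tests only run with the RUN_SLOW=1 environment variable set, e.g. RUN_SLOW=1 python -m pytest -k test_output_format, assuming the usual pytest layout.)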