From fb38bb1621976e6695ee8f8b2e95ced877b69c3b Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Thu, 27 Oct 2022 22:44:35 +0200 Subject: [PATCH] Support grayscale images in `numpy_to_pil` (#1025) --- src/diffusers/pipeline_flax_utils.py | 6 +++++- src/diffusers/pipeline_utils.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 3c23693b40..e96c0c7467 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -444,7 +444,11 @@ class FlaxDiffusionPipeline(ConfigMixin): if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] return pil_images diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index c9c58a7488..c0a44363a2 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -625,7 +625,11 @@ class DiffusionPipeline(ConfigMixin): if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] return pil_images