From c6b04589b667ec09341cb195b7182e7d88bdf8a8 Mon Sep 17 00:00:00 2001 From: Yassine El Boudouri Date: Fri, 12 Jan 2024 02:50:24 +0100 Subject: [PATCH] Remove conversion to RGB (#6479) * Remove conversion to RGB * Add a Conversion Function * Add type hint for convert_method * Update src/diffusers/utils/loading_utils.py Update docstring Co-authored-by: Patrick von Platen * Update docstring * Optimize imports * Optimize imports (2) * Reformat code --------- Co-authored-by: Patrick von Platen Co-authored-by: Sayak Paul --- src/diffusers/utils/loading_utils.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index 279aa6fe73..e129d5f3e3 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,18 +1,24 @@ import os -from typing import Union +from typing import Callable, Union import PIL.Image import PIL.ImageOps import requests -def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: +def load_image( + image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None +) -> PIL.Image.Image: """ Loads `image` to a PIL Image. Args: image (`str` or `PIL.Image.Image`): The image to convert to the PIL Image format. + convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional): + A conversion method to apply to the image after loading it. + When set to `None` the image will be converted "RGB". + Returns: `PIL.Image.Image`: A PIL Image. @@ -24,14 +30,18 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: image = PIL.Image.open(image) else: raise ValueError( - f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path." ) - elif isinstance(image, PIL.Image.Image): - image = image else: raise ValueError( - "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." + "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." ) + image = PIL.ImageOps.exif_transpose(image) - image = image.convert("RGB") + + if convert_method is not None: + image = convert_method(image) + else: + image = image.convert("RGB") + return image