
refactor image_processor.py file (#9608)

* refactor image_processor file

* changes as requested

* +1 edits

* quality fix

* indent issue

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Charchit Sharma
2024-10-15 17:20:33 +05:30
committed by GitHub
parent dccf39f01e
commit 92d2baf643


@@ -38,16 +38,44 @@ PipelineImageInput = Union[
PipelineDepthInput = PipelineImageInput
def is_valid_image(image):
def is_valid_image(image) -> bool:
r"""
Checks if the input is a valid image.
A valid image can be:
- A `PIL.Image.Image`.
- A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
Args:
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
Returns:
`bool`:
`True` if the input is a valid image, `False` otherwise.
"""
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
def is_valid_image_imagelist(images):
# check if the image input is one of the supported formats for image and image list:
# it can be either one of below 3
# (1) a 4d pytorch tensor or numpy array,
# (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
# (3) a list of valid image
r"""
Checks if the input is a valid image or list of images.
The input can be one of the following formats:
- A 4D tensor or numpy array (batch of images).
- A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
`torch.Tensor`.
- A list of valid images.
Args:
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
images.
Returns:
`bool`:
`True` if the input is valid, `False` otherwise.
"""
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
return True
elif is_valid_image(images):
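Both helpers above are plain module-level functions, so they can be exercised directly. A minimal sketch, assuming the imports come from `diffusers.image_processor` and using illustrative sample inputs:

import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

pil_img = PIL.Image.new("RGB", (64, 64))    # single PIL image
gray = np.zeros((64, 64))                   # 2D array -> valid grayscale image
batch = torch.zeros(2, 3, 64, 64)           # 4D tensor -> batch of images

print(is_valid_image(pil_img))                    # True
print(is_valid_image(gray))                       # True
print(is_valid_image_imagelist([pil_img, gray]))  # True: list of valid images
print(is_valid_image_imagelist(batch))            # True: 4D batch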
@@ -103,8 +131,16 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
"""
r"""
Convert a numpy image or a batch of images to a PIL image.
Args:
images (`np.ndarray`):
The image array to convert to PIL format.
Returns:
`List[PIL.Image.Image]`:
A list of PIL images.
"""
if images.ndim == 3:
images = images[None, ...]
@@ -119,8 +155,16 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
"""
r"""
Convert a PIL image or a list of PIL images to NumPy arrays.
Args:
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
The PIL image or list of images to convert to NumPy format.
Returns:
`np.ndarray`:
A NumPy array representation of the images.
"""
if not isinstance(images, list):
images = [images]
@@ -131,8 +175,16 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
"""
r"""
Convert a NumPy image to a PyTorch tensor.
Args:
images (`np.ndarray`):
The NumPy image array to convert to PyTorch format.
Returns:
`torch.Tensor`:
A PyTorch tensor representation of the images.
"""
if images.ndim == 3:
images = images[..., None]
@@ -142,30 +194,62 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
"""
r"""
Convert a PyTorch tensor to a NumPy image.
Args:
images (`torch.Tensor`):
The PyTorch tensor to convert to NumPy format.
Returns:
`np.ndarray`:
A NumPy array representation of the images.
"""
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
return images
@staticmethod
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""
r"""
Normalize an image array to [-1,1].
Args:
images (`np.ndarray` or `torch.Tensor`):
The image array to normalize.
Returns:
`np.ndarray` or `torch.Tensor`:
The normalized image array.
"""
return 2.0 * images - 1.0
@staticmethod
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""
r"""
Denormalize an image array to [0,1].
Args:
images (`np.ndarray` or `torch.Tensor`):
The image array to denormalize.
Returns:
`np.ndarray` or `torch.Tensor`:
The denormalized image array.
"""
return (images / 2 + 0.5).clamp(0, 1)
@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
"""
r"""
Converts a PIL image to RGB format.
Args:
image (`PIL.Image.Image`):
The PIL image to convert to RGB.
Returns:
`PIL.Image.Image`:
The RGB-converted PIL image.
"""
image = image.convert("RGB")
@@ -173,8 +257,16 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
"""
Converts a PIL image to grayscale format.
r"""
Converts a given PIL image to grayscale.
Args:
image (`PIL.Image.Image`):
The input image to convert.
Returns:
`PIL.Image.Image`:
The image converted to grayscale.
"""
image = image.convert("L")
@@ -182,8 +274,16 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
"""
r"""
Applies Gaussian blur to an image.
Args:
image (`PIL.Image.Image`):
The PIL image to blur.
blur_factor (`int`, *optional*, defaults to `4`):
The radius of the Gaussian blur to apply.
Returns:
`PIL.Image.Image`:
The blurred PIL image.
"""
image = image.filter(ImageFilter.GaussianBlur(blur_factor))
@@ -191,7 +291,7 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
"""
r"""
Finds a rectangular region that contains all masked areas in an image, and expands the region to match the aspect
ratio of the original image; for example, if the user drew a mask in a 128x32 region, and the dimensions for
processing are 512x512, the region will be expanded to 128x128.
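A short sketch of the behavior described above, using the 128x32 mask and 512x512 processing size from the docstring's example (the mask position and the exact returned coordinates are illustrative):

import PIL.Image
import PIL.ImageDraw

from diffusers.image_processor import VaeImageProcessor

mask = PIL.Image.new("L", (512, 512), 0)
PIL.ImageDraw.Draw(mask).rectangle([100, 200, 228, 232], fill=255)  # roughly 128x32 masked area

# Returns (x1, y1, x2, y2), expanded so the region matches the 512x512 aspect ratio.
x1, y1, x2, y2 = VaeImageProcessor.get_crop_region(mask, width=512, height=512, pad=0)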
@@ -285,14 +385,21 @@ class VaeImageProcessor(ConfigMixin):
width: int,
height: int,
) -> PIL.Image.Image:
"""
r"""
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, filling empty areas with data from the image.
Args:
image: The image to resize.
width: The width to resize the image to.
height: The height to resize the image to.
image (`PIL.Image.Image`):
The image to resize and fill.
width (`int`):
The width to resize the image to.
height (`int`):
The height to resize the image to.
Returns:
`PIL.Image.Image`:
The resized and filled image.
"""
ratio = width / height
@@ -330,14 +437,21 @@ class VaeImageProcessor(ConfigMixin):
width: int,
height: int,
) -> PIL.Image.Image:
"""
r"""
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, cropping the excess.
Args:
image: The image to resize.
width: The width to resize the image to.
height: The height to resize the image to.
image (`PIL.Image.Image`):
The image to resize and crop.
width (`int`):
The width to resize the image to.
height (`int`):
The height to resize the image to.
Returns:
`PIL.Image.Image`:
The resized and cropped image.
"""
ratio = width / height
src_ratio = image.width / image.height
@@ -429,19 +543,23 @@ class VaeImageProcessor(ConfigMixin):
height: Optional[int] = None,
width: Optional[int] = None,
) -> Tuple[int, int]:
"""
This function return the height and width that are downscaled to the next integer multiple of
`vae_scale_factor`.
r"""
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
Args:
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
have shape `[batch, channel, height, width]`.
height (`int`, *optional*, defaults to `None`):
The height in preprocessed image. If `None`, will use the height of `image` input.
width (`int`, *optional*`, defaults to `None`):
The width in preprocessed. If `None`, will use the width of the `image` input.
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
tensor, it should have shape `[batch, channels, height, width]`.
height (`Optional[int]`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
width (`Optional[int]`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
Returns:
`Tuple[int, int]`:
A tuple containing the height and width, both downscaled to the next integer multiple of `vae_scale_factor`.
"""
if height is None:
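The downscaling documented above snaps each dimension down to a multiple of the processor's `vae_scale_factor`; a hedged sketch (the scale factor of 8 and the 513x769 input are illustrative):

import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
image = torch.zeros(1, 3, 513, 769)  # [batch, channels, height, width]

# 513 -> 512 and 769 -> 768, the next lower multiples of 8.
height, width = processor.get_default_height_width(image)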
@@ -478,13 +596,13 @@ class VaeImageProcessor(ConfigMixin):
Preprocess the image input.
Args:
image (`pipeline_image_input`):
image (`PipelineImageInput`):
The image input. Accepted formats are PIL images, NumPy arrays, and PyTorch tensors; lists of supported formats are also accepted.
height (`int`, *optional*, defaults to `None`):
height (`int`, *optional*):
The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default height.
width (`int`, *optional*`, defaults to `None`):
width (`int`, *optional*):
The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
resize_mode (`str`, *optional*, defaults to `default`):
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
@@ -496,6 +614,10 @@ class VaeImageProcessor(ConfigMixin):
supported for PIL image input.
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
The crop coordinates for each image in the batch. If `None`, will not crop the image.
Returns:
`torch.Tensor`:
The preprocessed image.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
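A minimal end-to-end sketch of `preprocess` with the default processor config (the 600x512 input and the target size are illustrative):

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
img = PIL.Image.new("RGB", (600, 512))

# Resizes/crops to 512x512, converts to a tensor, and normalizes to [-1, 1].
batch = processor.preprocess(img, height=512, width=512, resize_mode="crop")
print(batch.shape)  # torch.Size([1, 3, 512, 512])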
@@ -655,8 +777,22 @@ class VaeImageProcessor(ConfigMixin):
image: PIL.Image.Image,
crop_coords: Optional[Tuple[int, int, int, int]] = None,
) -> PIL.Image.Image:
"""
overlay the inpaint output to the original image
r"""
Applies an overlay of the mask and the inpainted image on the original image.
Args:
mask (`PIL.Image.Image`):
The mask image that highlights regions to overlay.
init_image (`PIL.Image.Image`):
The original image to which the overlay is applied.
image (`PIL.Image.Image`):
The image to overlay onto the original.
crop_coords (`Tuple[int, int, int, int]`, *optional*):
Coordinates to crop the image. If provided, the image will be cropped accordingly.
Returns:
`PIL.Image.Image`:
The final image with the overlay applied.
"""
width, height = image.width, image.height
@@ -713,8 +849,16 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
@staticmethod
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
"""
Convert a NumPy image or a batch of images to a PIL image.
r"""
Convert a NumPy image or a batch of images to a list of PIL images.
Args:
images (`np.ndarray`):
The input NumPy array of images, which can be a single image or a batch.
Returns:
`List[PIL.Image.Image]`:
A list of PIL images converted from the input NumPy array.
"""
if images.ndim == 3:
images = images[None, ...]
@@ -729,8 +873,16 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
@staticmethod
def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
"""
r"""
Convert a PIL image or a list of PIL images to NumPy arrays.
Args:
images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
The input image or list of images to be converted.
Returns:
`np.ndarray`:
A NumPy array of the converted images.
"""
if not isinstance(images, list):
images = [images]
@@ -741,18 +893,30 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
@staticmethod
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""
r"""
Convert an RGB-like depth image to a depth map.
Args:
image: RGB-like depth image
Returns: depth map
image (`Union[np.ndarray, torch.Tensor]`):
The RGB-like depth image to convert.
Returns:
`Union[np.ndarray, torch.Tensor]`:
The corresponding depth map.
"""
return image[:, :, 1] * 2**8 + image[:, :, 2]
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
"""
Convert a NumPy depth image or a batch of images to a PIL image.
r"""
Convert a NumPy depth image or a batch of images to a list of PIL images.
Args:
images (`np.ndarray`):
The input NumPy array of depth images, which can be a single image or a batch.
Returns:
`List[PIL.Image.Image]`:
A list of PIL images converted from the input NumPy depth images.
"""
if images.ndim == 3:
images = images[None, ...]
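As the `rgblike_to_depthmap` formula shows, the 16-bit depth value is packed into the G (high byte) and B (low byte) channels; a small decoding sketch with illustrative values:

import numpy as np

from diffusers.image_processor import VaeImageProcessorLDM3D

rgb_like = np.zeros((2, 2, 3), dtype=np.uint16)  # [height, width, 3] RGB-like depth image
rgb_like[0, 0, 1] = 1    # G channel: high byte
rgb_like[0, 0, 2] = 44   # B channel: low byte

depth = VaeImageProcessorLDM3D.rgblike_to_depthmap(rgb_like)
print(depth[0, 0])  # 1 * 2**8 + 44 = 300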
@@ -833,8 +997,24 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
width: Optional[int] = None,
target_res: Optional[int] = None,
) -> torch.Tensor:
"""
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
r"""
Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors.
Args:
rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
The RGB input image, which can be a single image or a batch.
depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
The depth input image, which can be a single image or a batch.
height (`Optional[int]`, *optional*, defaults to `None`):
The desired height of the processed image. If `None`, defaults to the height of the input image.
width (`Optional[int]`, *optional*, defaults to `None`):
The desired width of the processed image. If `None`, defaults to the width of the input image.
target_res (`Optional[int]`, *optional*, defaults to `None`):
Target resolution for resizing the images. If specified, overrides height and width.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing the processed RGB and depth images as PyTorch tensors.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
@@ -1072,7 +1252,17 @@ class PixArtImageProcessor(VaeImageProcessor):
@staticmethod
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
"""Returns binned height and width."""
r"""
Returns the binned height and width based on the aspect ratio.
Args:
height (`int`): The height of the image.
width (`int`): The width of the image.
ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
Returns:
`Tuple[int, int]`: The closest binned height and width.
"""
ar = float(height / width)
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
default_hw = ratios[closest_ratio]
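A small sketch of the binning logic above; the two-entry `ratios` dict is illustrative (real pipelines pass the PixArt aspect-ratio tables):

from diffusers.image_processor import PixArtImageProcessor

# Keys are aspect ratios (height / width) as strings, values are (height, width) bins.
ratios = {"1.0": (1024, 1024), "0.5": (704, 1408)}

# 600 / 1200 = 0.5, which is closest to the "0.5" bin.
print(PixArtImageProcessor.classify_height_width_bin(600, 1200, ratios))  # (704, 1408)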
@@ -1080,6 +1270,19 @@ class PixArtImageProcessor(VaeImageProcessor):
@staticmethod
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
r"""
Resizes and crops a tensor of images to the specified dimensions.
Args:
samples (`torch.Tensor`):
A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height,
and W is the width.
new_width (`int`): The desired width of the output images.
new_height (`int`): The desired height of the output images.
Returns:
`torch.Tensor`: A tensor containing the resized and cropped images.
"""
orig_height, orig_width = samples.shape[2], samples.shape[3]
# Check if resizing is needed
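A hedged usage sketch of `resize_and_crop_tensor` (the batch, source size, and target size are illustrative):

import torch

from diffusers.image_processor import PixArtImageProcessor

samples = torch.rand(2, 3, 512, 768)  # (N, C, H, W)

# Scales the batch so the target fits, then center-crops to exactly 256x256.
out = PixArtImageProcessor.resize_and_crop_tensor(samples, new_width=256, new_height=256)
print(out.shape)  # torch.Size([2, 3, 256, 256])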