Mirror of https://github.com/huggingface/diffusers.git
refactor image_processor.py file (#9608)
* refactor image_processor file
* changes as requested
* +1 edits
* quality fix
* indent issue

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
@@ -38,16 +38,44 @@ PipelineImageInput = Union[
 PipelineDepthInput = PipelineImageInput
 
 
-def is_valid_image(image):
+def is_valid_image(image) -> bool:
+    r"""
+    Checks if the input is a valid image.
+
+    A valid image can be:
+    - A `PIL.Image.Image`.
+    - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
+
+    Args:
+        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+            The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
+
+    Returns:
+        `bool`:
+            `True` if the input is a valid image, `False` otherwise.
+    """
     return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
 
 
 def is_valid_image_imagelist(images):
-    # check if the image input is one of the supported formats for image and image list:
-    # it can be either one of below 3
-    # (1) a 4d pytorch tensor or numpy array,
-    # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
-    # (3) a list of valid image
+    r"""
+    Checks if the input is a valid image or list of images.
+
+    The input can be one of the following formats:
+    - A 4D tensor or numpy array (batch of images).
+    - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
+      `torch.Tensor`.
+    - A list of valid images.
+
+    Args:
+        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
+            The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
+            images.
+
+    Returns:
+        `bool`:
+            `True` if the input is valid, `False` otherwise.
+    """
     if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
         return True
    elif is_valid_image(images):
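A minimal usage sketch for the two validators in this hunk, assuming they are imported straight from `diffusers.image_processor`, where they are defined at module level:

    import numpy as np
    import PIL.Image
    from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

    pil_img = PIL.Image.new("RGB", (64, 64))   # valid: PIL image
    gray = np.zeros((64, 64))                  # valid: 2D array (grayscale)
    batch = np.zeros((2, 64, 64, 3))           # valid: 4D batch of images

    assert is_valid_image(pil_img) and is_valid_image(gray)
    assert is_valid_image_imagelist(batch)
    assert is_valid_image_imagelist([pil_img, pil_img])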
@@ -103,8 +131,16 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
+        r"""
         Convert a numpy image or a batch of images to a PIL image.
+
+        Args:
+            images (`np.ndarray`):
+                The image array to convert to PIL format.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images.
         """
         if images.ndim == 3:
             images = images[None, ...]
@@ -119,8 +155,16 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
+        r"""
         Convert a PIL image or a list of PIL images to NumPy arrays.
+
+        Args:
+            images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
+                The PIL image or list of images to convert to NumPy format.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array representation of the images.
         """
         if not isinstance(images, list):
             images = [images]
@@ -131,8 +175,16 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
-        """
+        r"""
         Convert a NumPy image to a PyTorch tensor.
+
+        Args:
+            images (`np.ndarray`):
+                The NumPy image array to convert to PyTorch format.
+
+        Returns:
+            `torch.Tensor`:
+                A PyTorch tensor representation of the images.
         """
         if images.ndim == 3:
             images = images[..., None]
@@ -142,30 +194,62 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
-        """
+        r"""
         Convert a PyTorch tensor to a NumPy image.
+
+        Args:
+            images (`torch.Tensor`):
+                The PyTorch tensor to convert to NumPy format.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array representation of the images.
         """
         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
         return images
 
     @staticmethod
     def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
+        r"""
         Normalize an image array to [-1,1].
+
+        Args:
+            images (`np.ndarray` or `torch.Tensor`):
+                The image array to normalize.
+
+        Returns:
+            `np.ndarray` or `torch.Tensor`:
+                The normalized image array.
         """
         return 2.0 * images - 1.0
 
     @staticmethod
     def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
+        r"""
         Denormalize an image array to [0,1].
+
+        Args:
+            images (`np.ndarray` or `torch.Tensor`):
+                The image array to denormalize.
+
+        Returns:
+            `np.ndarray` or `torch.Tensor`:
+                The denormalized image array.
         """
         return (images / 2 + 0.5).clamp(0, 1)
 
     @staticmethod
     def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
-        """
+        r"""
         Converts a PIL image to RGB format.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The PIL image to convert to RGB.
+
+        Returns:
+            `PIL.Image.Image`:
+                The RGB-converted PIL image.
         """
         image = image.convert("RGB")
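The converters and the normalization pair in this hunk compose into a round trip; a minimal sketch using the static methods as documented above:

    import PIL.Image
    from diffusers.image_processor import VaeImageProcessor

    pil_img = PIL.Image.new("RGB", (64, 64), "white")

    arr = VaeImageProcessor.pil_to_numpy(pil_img)   # (1, 64, 64, 3), float32 in [0, 1]
    pt = VaeImageProcessor.numpy_to_pt(arr)         # (1, 3, 64, 64) torch.Tensor
    pt = VaeImageProcessor.normalize(pt)            # rescaled to [-1, 1]
    pt = VaeImageProcessor.denormalize(pt)          # back to [0, 1] (clamped)
    arr = VaeImageProcessor.pt_to_numpy(pt)         # (1, 64, 64, 3)
    pils = VaeImageProcessor.numpy_to_pil(arr)      # [PIL.Image.Image]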
@@ -173,8 +257,16 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
-        """
-        Converts a PIL image to grayscale format.
+        r"""
+        Converts a given PIL image to grayscale.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The input image to convert.
+
+        Returns:
+            `PIL.Image.Image`:
+                The image converted to grayscale.
         """
         image = image.convert("L")
@@ -182,8 +274,16 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
-        """
+        r"""
         Applies Gaussian blur to an image.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The PIL image to blur.
+
+        Returns:
+            `PIL.Image.Image`:
+                The blurred PIL image.
         """
         image = image.filter(ImageFilter.GaussianBlur(blur_factor))
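The three PIL helpers above are plain static methods; a short sketch:

    import PIL.Image
    from diffusers.image_processor import VaeImageProcessor

    img = PIL.Image.new("RGB", (64, 64), "white")

    rgb = VaeImageProcessor.convert_to_rgb(img)         # force 3-channel "RGB" mode
    gray = VaeImageProcessor.convert_to_grayscale(img)  # single-channel "L" mode
    soft = VaeImageProcessor.blur(img, blur_factor=4)   # Gaussian blur, radius 4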
@@ -191,7 +291,7 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
-        """
+        r"""
         Finds a rectangular region that contains all masked areas in an image, and expands the region to match the
         aspect ratio of the original image; for example, if a user drew a mask in a 128x32 region, and the dimensions
         for processing are 512x512, the region will be expanded to 128x128.
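A sketch of `get_crop_region` on a synthetic mask, mirroring the 128x32 example in the docstring (the exact box also depends on `pad` and the mask position):

    import PIL.Image
    import PIL.ImageDraw
    from diffusers.image_processor import VaeImageProcessor

    mask = PIL.Image.new("L", (512, 512), 0)
    # draw a roughly 128x32 masked area
    PIL.ImageDraw.Draw(mask).rectangle((100, 200, 227, 231), fill=255)

    # the region is expanded toward the 1:1 aspect ratio of the 512x512 processing size
    x1, y1, x2, y2 = VaeImageProcessor.get_crop_region(mask, width=512, height=512)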
@@ -285,14 +385,21 @@ class VaeImageProcessor(ConfigMixin):
         width: int,
         height: int,
     ) -> PIL.Image.Image:
-        """
+        r"""
         Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
         the image within the dimensions, filling the empty space with data from the image.
 
         Args:
-            image: The image to resize.
-            width: The width to resize the image to.
-            height: The height to resize the image to.
+            image (`PIL.Image.Image`):
+                The image to resize and fill.
+            width (`int`):
+                The width to resize the image to.
+            height (`int`):
+                The height to resize the image to.
+
+        Returns:
+            `PIL.Image.Image`:
+                The resized and filled image.
         """
 
         ratio = width / height
@@ -330,14 +437,21 @@ class VaeImageProcessor(ConfigMixin):
         width: int,
         height: int,
     ) -> PIL.Image.Image:
-        """
+        r"""
         Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
         the image within the dimensions, cropping the excess.
 
         Args:
-            image: The image to resize.
-            width: The width to resize the image to.
-            height: The height to resize the image to.
+            image (`PIL.Image.Image`):
+                The image to resize and crop.
+            width (`int`):
+                The width to resize the image to.
+            height (`int`):
+                The height to resize the image to.
+
+        Returns:
+            `PIL.Image.Image`:
+                The resized and cropped image.
         """
         ratio = width / height
         src_ratio = image.width / image.height
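Both private helpers above are reached through the public `resize` method via its `resize_mode` argument (only supported for PIL input); a sketch:

    import PIL.Image
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor()
    img = PIL.Image.new("RGB", (640, 360))

    filled = processor.resize(img, height=512, width=512, resize_mode="fill")   # fill empty space from image data
    cropped = processor.resize(img, height=512, width=512, resize_mode="crop")  # center-crop the excess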
@@ -429,19 +543,23 @@ class VaeImageProcessor(ConfigMixin):
         height: Optional[int] = None,
         width: Optional[int] = None,
     ) -> Tuple[int, int]:
-        """
-        This function return the height and width that are downscaled to the next integer multiple of
-        `vae_scale_factor`.
+        r"""
+        Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
 
         Args:
-            image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
-                shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
-                have shape `[batch, channel, height, width]`.
-            height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the height of `image` input.
-            width (`int`, *optional*`, defaults to `None`):
-                The width in preprocessed. If `None`, will use the width of the `image` input.
+            image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+                The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
+                should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
+                tensor, it should have shape `[batch, channels, height, width]`.
+            height (`Optional[int]`, *optional*, defaults to `None`):
+                The height of the preprocessed image. If `None`, the height of the `image` input will be used.
+            width (`Optional[int]`, *optional*, defaults to `None`):
+                The width of the preprocessed image. If `None`, the width of the `image` input will be used.
+
+        Returns:
+            `Tuple[int, int]`:
+                A tuple containing the height and width, both resized to the nearest integer multiple of
+                `vae_scale_factor`.
         """
 
         if height is None:
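For example, with the default `vae_scale_factor` of 8, odd dimensions are rounded down to the next multiple of 8; a sketch:

    import PIL.Image
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor(vae_scale_factor=8)
    img = PIL.Image.new("RGB", (513, 769))  # width 513, height 769

    height, width = processor.get_default_height_width(img)
    print(height, width)  # 768 512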
@@ -478,13 +596,13 @@ class VaeImageProcessor(ConfigMixin):
         Preprocess the image input.
 
         Args:
-            image (`pipeline_image_input`):
+            image (`PipelineImageInput`):
                 The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; also accepts lists of
                 supported formats.
-            height (`int`, *optional*, defaults to `None`):
+            height (`int`, *optional*):
                 The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the
                 default height.
-            width (`int`, *optional*`, defaults to `None`):
+            width (`int`, *optional*):
                 The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
             resize_mode (`str`, *optional*, defaults to `default`):
                 The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
@@ -496,6 +614,10 @@ class VaeImageProcessor(ConfigMixin):
                 supported for PIL image input.
             crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                 The crop coordinates for each image in the batch. If `None`, will not crop the image.
+
+        Returns:
+            `torch.Tensor`:
+                The preprocessed image.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
 
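End to end, `preprocess` resizes the input to VAE-friendly dimensions and normalizes it to [-1, 1]; a minimal sketch with the default settings:

    import PIL.Image
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor(vae_scale_factor=8)
    img = PIL.Image.new("RGB", (640, 480))

    tensor = processor.preprocess(img, height=512, width=512)
    print(tensor.shape)  # torch.Size([1, 3, 512, 512]), values in [-1, 1]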
@@ -655,8 +777,22 @@ class VaeImageProcessor(ConfigMixin):
         image: PIL.Image.Image,
         crop_coords: Optional[Tuple[int, int, int, int]] = None,
     ) -> PIL.Image.Image:
-        """
-        overlay the inpaint output to the original image
+        r"""
+        Applies an overlay of the mask and the inpainted image on the original image.
+
+        Args:
+            mask (`PIL.Image.Image`):
+                The mask image that highlights regions to overlay.
+            init_image (`PIL.Image.Image`):
+                The original image to which the overlay is applied.
+            image (`PIL.Image.Image`):
+                The image to overlay onto the original.
+            crop_coords (`Tuple[int, int, int, int]`, *optional*):
+                Coordinates to crop the image. If provided, the image will be cropped accordingly.
+
+        Returns:
+            `PIL.Image.Image`:
+                The final image with the overlay applied.
         """
 
         width, height = image.width, image.height
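A sketch of `apply_overlay`, which the inpainting pipelines use to paste the generated region back onto the untouched original:

    import PIL.Image
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor()
    init_image = PIL.Image.new("RGB", (512, 512), "white")  # original input
    inpainted = PIL.Image.new("RGB", (512, 512), "black")   # pipeline output
    mask = PIL.Image.new("L", (512, 512), 0)                # 0 keeps the original pixels

    merged = processor.apply_overlay(mask, init_image, inpainted)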
@@ -713,8 +849,16 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
 
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy image or a batch of images to a PIL image.
+        r"""
+        Convert a NumPy image or a batch of images to a list of PIL images.
+
+        Args:
+            images (`np.ndarray`):
+                The input NumPy array of images, which can be a single image or a batch.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images converted from the input NumPy array.
         """
         if images.ndim == 3:
             images = images[None, ...]
@@ -729,8 +873,16 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
 
     @staticmethod
    def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
+        r"""
         Convert a PIL image or a list of PIL images to NumPy arrays.
+
+        Args:
+            images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
+                The input image or list of images to be converted.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array of the converted images.
         """
         if not isinstance(images, list):
             images = [images]
@@ -741,18 +893,30 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
 
     @staticmethod
     def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
+        r"""
         Convert an RGB-like depth image to a depth map.
 
         Args:
-            image: RGB-like depth image
-
-        Returns: depth map
+            image (`Union[np.ndarray, torch.Tensor]`):
+                The RGB-like depth image to convert.
+
+        Returns:
+            `Union[np.ndarray, torch.Tensor]`:
+                The corresponding depth map.
         """
         return image[:, :, 1] * 2**8 + image[:, :, 2]
 
     def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy depth image or a batch of images to a PIL image.
+        r"""
+        Convert a NumPy depth image or a batch of images to a list of PIL images.
+
+        Args:
+            images (`np.ndarray`):
+                The input NumPy array of depth images, which can be a single image or a batch.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images converted from the input NumPy depth images.
         """
         if images.ndim == 3:
             images = images[None, ...]
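The RGB-like encoding packs a 16-bit depth value into the G (high byte) and B (low byte) channels, so `depth = G * 2**8 + B`; a small numeric sketch:

    import numpy as np
    from diffusers.image_processor import VaeImageProcessorLDM3D

    rgb_like = np.zeros((2, 2, 3), dtype=np.int64)
    rgb_like[0, 0, 1] = 1  # G channel (high byte)
    rgb_like[0, 0, 2] = 4  # B channel (low byte)

    depth = VaeImageProcessorLDM3D.rgblike_to_depthmap(rgb_like)
    print(depth[0, 0])  # 260  (= 1 * 256 + 4)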
@@ -833,8 +997,24 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         width: Optional[int] = None,
         target_res: Optional[int] = None,
     ) -> torch.Tensor:
-        """
-        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
+        r"""
+        Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors.
+
+        Args:
+            rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
+                The RGB input image, which can be a single image or a batch.
+            depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
+                The depth input image, which can be a single image or a batch.
+            height (`Optional[int]`, *optional*, defaults to `None`):
+                The desired height of the processed image. If `None`, defaults to the height of the input image.
+            width (`Optional[int]`, *optional*, defaults to `None`):
+                The desired width of the processed image. If `None`, defaults to the width of the input image.
+            target_res (`Optional[int]`, *optional*, defaults to `None`):
+                Target resolution for resizing the images. If specified, overrides height and width.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing the processed RGB and depth images as PyTorch tensors.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
 
@@ -1072,7 +1252,17 @@ class PixArtImageProcessor(VaeImageProcessor):
 
     @staticmethod
     def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
-        """Returns binned height and width."""
+        r"""
+        Returns the binned height and width based on the aspect ratio.
+
+        Args:
+            height (`int`): The height of the image.
+            width (`int`): The width of the image.
+            ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
+
+        Returns:
+            `Tuple[int, int]`: The closest binned height and width.
+        """
         ar = float(height / width)
         closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
         default_hw = ratios[closest_ratio]
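A sketch with a toy `ratios` dict (the PixArt pipelines ship full `ASPECT_RATIO_*_BIN` tables in this shape):

    from diffusers.image_processor import PixArtImageProcessor

    # keys are aspect ratios (height/width) as strings, values the binned (height, width)
    ratios = {"0.5": (512, 1024), "1.0": (1024, 1024), "2.0": (1024, 512)}

    # 700x800 has aspect ratio 0.875, closest to 1.0, so it snaps to (1024, 1024)
    height, width = PixArtImageProcessor.classify_height_width_bin(700, 800, ratios)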
@@ -1080,6 +1270,19 @@ class PixArtImageProcessor(VaeImageProcessor):
 
     @staticmethod
     def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+        r"""
+        Resizes and crops a tensor of images to the specified dimensions.
+
+        Args:
+            samples (`torch.Tensor`):
+                A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the
+                height, and W is the width.
+            new_width (`int`): The desired width of the output images.
+            new_height (`int`): The desired height of the output images.
+
+        Returns:
+            `torch.Tensor`: A tensor containing the resized and cropped images.
+        """
         orig_height, orig_width = samples.shape[2], samples.shape[3]
 
         # Check if resizing is needed
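`resize_and_crop_tensor` pairs with the binning above: generate at the binned size, then map back to the requested size; a sketch:

    import torch
    from diffusers.image_processor import PixArtImageProcessor

    samples = torch.randn(2, 3, 1024, 1024)  # (N, C, H, W)

    # resize so the image covers the target box, then center-crop the overflow
    out = PixArtImageProcessor.resize_and_crop_tensor(samples, new_width=768, new_height=512)
    print(out.shape)  # torch.Size([2, 3, 512, 768])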