mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
remove the deprecated prepare_mask_and_masked_image function (#8512)
remove prepare mask fn Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -118,129 +118,6 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
|
||||
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
|
||||
"""
|
||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||
``image`` and ``1`` for the ``mask``.
|
||||
|
||||
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
||||
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
||||
|
||||
Args:
|
||||
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
||||
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
||||
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
||||
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
||||
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
||||
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
||||
(ot the other way around).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
|
||||
deprecate(
|
||||
"prepare_mask_and_masked_image",
|
||||
"0.30.0",
|
||||
deprecation_message,
|
||||
)
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
if mask is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined.")
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
|
||||
|
||||
# Batch single image
|
||||
if image.ndim == 3:
|
||||
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
# Batch and add channel dim for single mask
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Batch single mask or add channel dim
|
||||
if mask.ndim == 3:
|
||||
# Single batched mask, no channel dim or single mask not batched but channel dim
|
||||
if mask.shape[0] == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
# Batched masks no channel dim
|
||||
else:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
|
||||
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
|
||||
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
|
||||
|
||||
# Check image is in [-1, 1]
|
||||
if image.min() < -1 or image.max() > 1:
|
||||
raise ValueError("Image should be in [-1, 1] range")
|
||||
|
||||
# Check mask is in [0, 1]
|
||||
if mask.min() < 0 or mask.max() > 1:
|
||||
raise ValueError("Mask should be in [0, 1] range")
|
||||
|
||||
# Binarize mask
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
# Image as float32
|
||||
image = image.to(dtype=torch.float32)
|
||||
elif isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
|
||||
else:
|
||||
# preprocess image
|
||||
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
||||
# resize all images w.r.t passed height an width
|
||||
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
|
||||
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
|
||||
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
# n.b. ensure backwards compatibility as old function does not return image
|
||||
if return_image:
|
||||
return mask, masked_image, image
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class StableDiffusionControlNetInpaintPipeline(
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -38,128 +37,6 @@ from .safety_checker import StableDiffusionSafetyChecker
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
|
||||
"""
|
||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||
``image`` and ``1`` for the ``mask``.
|
||||
|
||||
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
||||
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
||||
|
||||
Args:
|
||||
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
||||
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
||||
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
||||
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
||||
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
||||
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
||||
(ot the other way around).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
|
||||
deprecate(
|
||||
"prepare_mask_and_masked_image",
|
||||
"0.30.0",
|
||||
deprecation_message,
|
||||
)
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
if mask is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined.")
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
|
||||
|
||||
# Batch single image
|
||||
if image.ndim == 3:
|
||||
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
# Batch and add channel dim for single mask
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Batch single mask or add channel dim
|
||||
if mask.ndim == 3:
|
||||
# Single batched mask, no channel dim or single mask not batched but channel dim
|
||||
if mask.shape[0] == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
# Batched masks no channel dim
|
||||
else:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
|
||||
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
|
||||
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
|
||||
|
||||
# Check image is in [-1, 1]
|
||||
if image.min() < -1 or image.max() > 1:
|
||||
raise ValueError("Image should be in [-1, 1] range")
|
||||
|
||||
# Check mask is in [0, 1]
|
||||
if mask.min() < 0 or mask.max() > 1:
|
||||
raise ValueError("Mask should be in [0, 1] range")
|
||||
|
||||
# Binarize mask
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
# Image as float32
|
||||
image = image.to(dtype=torch.float32)
|
||||
elif isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
|
||||
else:
|
||||
# preprocess image
|
||||
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
||||
# resize all images w.r.t passed height an width
|
||||
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
|
||||
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
|
||||
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
# n.b. ensure backwards compatibility as old function does not return image
|
||||
if return_image:
|
||||
return mask, masked_image, image
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
|
||||
@@ -132,124 +132,6 @@ def mask_pil_to_torch(mask, height, width):
|
||||
return mask
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
|
||||
"""
|
||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||
``image`` and ``1`` for the ``mask``.
|
||||
|
||||
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
||||
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
||||
|
||||
Args:
|
||||
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
||||
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
||||
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
||||
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
||||
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
||||
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
||||
(ot the other way around).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
|
||||
# checkpoint. TOD(Yiyi) - need to clean this up later
|
||||
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
|
||||
deprecate(
|
||||
"prepare_mask_and_masked_image",
|
||||
"0.30.0",
|
||||
deprecation_message,
|
||||
)
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
if mask is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined.")
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
mask = mask_pil_to_torch(mask, height, width)
|
||||
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
# Batch and add channel dim for single mask
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Batch single mask or add channel dim
|
||||
if mask.ndim == 3:
|
||||
# Single batched mask, no channel dim or single mask not batched but channel dim
|
||||
if mask.shape[0] == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
# Batched masks no channel dim
|
||||
else:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
|
||||
# assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
|
||||
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
|
||||
|
||||
# Check image is in [-1, 1]
|
||||
# if image.min() < -1 or image.max() > 1:
|
||||
# raise ValueError("Image should be in [-1, 1] range")
|
||||
|
||||
# Check mask is in [0, 1]
|
||||
if mask.min() < 0 or mask.max() > 1:
|
||||
raise ValueError("Mask should be in [0, 1] range")
|
||||
|
||||
# Binarize mask
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
# Image as float32
|
||||
image = image.to(dtype=torch.float32)
|
||||
elif isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
|
||||
else:
|
||||
# preprocess image
|
||||
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
||||
# resize all images w.r.t passed height an width
|
||||
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
mask = mask_pil_to_torch(mask, height, width)
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
if image.shape[1] == 4:
|
||||
# images are in latent space and thus can't
|
||||
# be masked set masked_image to None
|
||||
# we assume that the checkpoint is not an inpainting
|
||||
# checkpoint. TOD(Yiyi) - need to clean this up later
|
||||
masked_image = None
|
||||
else:
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
# n.b. ensure backwards compatibility as old function does not return image
|
||||
if return_image:
|
||||
return mask, masked_image, image
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
|
||||
@@ -36,7 +36,6 @@ from diffusers import (
|
||||
StableDiffusionInpaintPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
@@ -1105,530 +1104,3 @@ class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
|
||||
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
|
||||
def test_pil_inputs(self):
|
||||
height, width = 32, 32
|
||||
im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
|
||||
im = Image.fromarray(im)
|
||||
mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
|
||||
mask = Image.fromarray((mask * 255).astype(np.uint8))
|
||||
|
||||
t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True)
|
||||
|
||||
self.assertTrue(isinstance(t_mask, torch.Tensor))
|
||||
self.assertTrue(isinstance(t_masked, torch.Tensor))
|
||||
self.assertTrue(isinstance(t_image, torch.Tensor))
|
||||
|
||||
self.assertEqual(t_mask.ndim, 4)
|
||||
self.assertEqual(t_masked.ndim, 4)
|
||||
self.assertEqual(t_image.ndim, 4)
|
||||
|
||||
self.assertEqual(t_mask.shape, (1, 1, height, width))
|
||||
self.assertEqual(t_masked.shape, (1, 3, height, width))
|
||||
self.assertEqual(t_image.shape, (1, 3, height, width))
|
||||
|
||||
self.assertTrue(t_mask.dtype == torch.float32)
|
||||
self.assertTrue(t_masked.dtype == torch.float32)
|
||||
self.assertTrue(t_image.dtype == torch.float32)
|
||||
|
||||
self.assertTrue(t_mask.min() >= 0.0)
|
||||
self.assertTrue(t_mask.max() <= 1.0)
|
||||
self.assertTrue(t_masked.min() >= -1.0)
|
||||
self.assertTrue(t_masked.min() <= 1.0)
|
||||
self.assertTrue(t_image.min() >= -1.0)
|
||||
self.assertTrue(t_image.min() >= -1.0)
|
||||
|
||||
self.assertTrue(t_mask.sum() > 0.0)
|
||||
|
||||
def test_np_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
|
||||
im_pil = Image.fromarray(im_np)
|
||||
mask_np = (
|
||||
np.random.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
|
||||
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(
|
||||
im_pil, mask_pil, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_np == t_mask_pil).all())
|
||||
self.assertTrue((t_masked_np == t_masked_pil).all())
|
||||
self.assertTrue((t_image_np == t_image_pil).all())
|
||||
|
||||
def test_torch_3D_2D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy().transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_3D_3D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy().transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()[0]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_4D_2D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_4D_3D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()[0]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_4D_4D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
1,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()[0][0]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_batch_4D_3D(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
2,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
2,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
|
||||
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
|
||||
mask_nps = [mask.numpy() for mask in mask_tensor]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
|
||||
t_mask_np = torch.cat([n[0] for n in nps])
|
||||
t_masked_np = torch.cat([n[1] for n in nps])
|
||||
t_image_np = torch.cat([n[2] for n in nps])
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_batch_4D_4D(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
2,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
2,
|
||||
1,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
|
||||
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
|
||||
mask_nps = [mask.numpy()[0] for mask in mask_tensor]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
|
||||
t_mask_np = torch.cat([n[0] for n in nps])
|
||||
t_masked_np = torch.cat([n[1] for n in nps])
|
||||
t_image_np = torch.cat([n[2] for n in nps])
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_shape_mismatch(self):
|
||||
height, width = 32, 32
|
||||
|
||||
# test height and width
|
||||
with self.assertRaises(AssertionError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.randn(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.randn(64, 64),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test batch dim
|
||||
with self.assertRaises(AssertionError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.randn(
|
||||
2,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.randn(4, 64, 64),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test batch dim
|
||||
with self.assertRaises(AssertionError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.randn(
|
||||
2,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.randn(4, 1, 64, 64),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
def test_type_mismatch(self):
|
||||
height, width = 32, 32
|
||||
|
||||
# test tensors-only
|
||||
with self.assertRaises(TypeError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
).numpy(),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test tensors-only
|
||||
with self.assertRaises(TypeError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
).numpy(),
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
def test_channels_first(self):
|
||||
height, width = 32, 32
|
||||
|
||||
# test channels first for 3D tensors
|
||||
with self.assertRaises(AssertionError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(height, width, 3),
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
def test_tensor_range(self):
|
||||
height, width = 32, 32
|
||||
|
||||
# test im <= 1
|
||||
with self.assertRaises(ValueError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.ones(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
)
|
||||
* 2,
|
||||
torch.rand(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test im >= -1
|
||||
with self.assertRaises(ValueError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.ones(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
)
|
||||
* (-2),
|
||||
torch.rand(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test mask <= 1
|
||||
with self.assertRaises(ValueError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.ones(
|
||||
height,
|
||||
width,
|
||||
)
|
||||
* 2,
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test mask >= 0
|
||||
with self.assertRaises(ValueError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.ones(
|
||||
height,
|
||||
width,
|
||||
)
|
||||
* -1,
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user