mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Adding 'strength' parameter to StableDiffusionInpaintingPipeline (#3424)
* Added explanation of 'strength' parameter * Added get_timesteps function which relies on new strength parameter * Added `strength` parameter which defaults to 1. * Swapped ordering so `noise_timestep` can be calculated before masking the image this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1. * Added strength to check_inputs, throws error if out of range * Changed `prepare_latents` to initialise latents w.r.t strength inspired from the stable diffusion img2img pipeline, init latents are initialised by converting the init image into a VAE latent and adding noise (based upon the strength parameter passed in), e.g. random when strength = 1, or the init image at strength = 0. * WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline still need to add correct regression values * Created a is_strength_max to initialise from pure random noise * Updated unit tests w.r.t new strength parameter + fixed new strength unit test * renamed parameter to avoid confusion with variable of same name * Updated regression values for new strength test - now passes * removed 'copied from' comment as this method is now different and divergent from the cpy * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Ensure backwards compatibility for prepare_mask_and_masked_image created a return_image boolean and initialised to false * Ensure backwards compatibility for prepare_latents * Fixed copy check typo * Fixes w.r.t backward compibility changes * make style * keep function argument ordering same for backwards compatibility in callees with copied from statements * make fix-copies --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: William Berman <WLBberman@gmail.com>
This commit is contained in:
@@ -99,7 +99,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
|
||||
def prepare_mask_and_masked_image(image, mask, height, width):
|
||||
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
|
||||
"""
|
||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||
@@ -209,6 +209,10 @@ def prepare_mask_and_masked_image(image, mask, height, width):
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
# n.b. ensure backwards compatibility as old function does not return image
|
||||
if return_image:
|
||||
return mask, masked_image, image
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
@@ -795,7 +799,20 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
image=None,
|
||||
timestep=None,
|
||||
is_strength_max=True,
|
||||
):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -803,13 +820,37 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if (image is None or timestep is None) and not is_strength_max:
|
||||
raise ValueError(
|
||||
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
|
||||
"However, either the image or the noise timestep has not been provided."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if is_strength_max:
|
||||
# if strength is 100% then simply initialise the latents to noise
|
||||
latents = noise
|
||||
else:
|
||||
# otherwise initialise latents as init image + noise
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
|
||||
|
||||
image_latents = self.vae.config.scaling_factor * image_latents
|
||||
|
||||
latents = self.scheduler.add_noise(image_latents, noise, timestep)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
return latents
|
||||
|
||||
def _default_height_width(self, height, width, image):
|
||||
|
||||
@@ -36,7 +36,7 @@ from .safety_checker import StableDiffusionSafetyChecker
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask, height, width):
|
||||
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
|
||||
"""
|
||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||
@@ -146,6 +146,10 @@ def prepare_mask_and_masked_image(image, mask, height, width):
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
# n.b. ensure backwards compatibility as old function does not return image
|
||||
if return_image:
|
||||
return mask, masked_image, image
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
@@ -552,17 +556,20 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
@@ -600,8 +607,20 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
image=None,
|
||||
timestep=None,
|
||||
is_strength_max=True,
|
||||
):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -609,13 +628,37 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if (image is None or timestep is None) and not is_strength_max:
|
||||
raise ValueError(
|
||||
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
|
||||
"However, either the image or the noise timestep has not been provided."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if is_strength_max:
|
||||
# if strength is 100% then simply initialise the latents to noise
|
||||
latents = noise
|
||||
else:
|
||||
# otherwise initialise latents as init image + noise
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
|
||||
|
||||
image_latents = self.vae.config.scaling_factor * image_latents
|
||||
|
||||
latents = self.scheduler.add_noise(image_latents, noise, timestep)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
return latents
|
||||
|
||||
def prepare_mask_latents(
|
||||
@@ -669,6 +712,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
||||
return mask, masked_image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -677,6 +730,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 1.0,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -710,6 +764,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
strength (`float`, *optional*, defaults to 1.):
|
||||
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
|
||||
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
|
||||
`strength`. The number of denoising steps depends on the amount of noise initially added. When
|
||||
`strength` is 1, added noise will be maximum and the denoising process will run for the full number of
|
||||
iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
|
||||
portion of the reference `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
@@ -802,6 +863,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
@@ -833,12 +895,20 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
|
||||
mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
|
||||
|
||||
# 5. set timesteps
|
||||
# 4. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
num_inference_steps=num_inference_steps, strength=strength, device=device
|
||||
)
|
||||
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
||||
is_strength_max = strength == 1.0
|
||||
|
||||
# 5. Preprocess mask and image
|
||||
mask, masked_image, init_image = prepare_mask_and_masked_image(
|
||||
image, mask_image, height, width, return_image=True
|
||||
)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
@@ -851,6 +921,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
image=init_image,
|
||||
timestep=latent_timestep,
|
||||
is_strength_max=is_strength_max,
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
|
||||
@@ -324,6 +324,26 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||
# verify that the returned image has the same height and width as the input height and width
|
||||
assert image.shape == (1, inputs["height"], inputs["width"], 3)
|
||||
|
||||
def test_stable_diffusion_inpaint_strength_test(self):
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None
|
||||
)
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
# change input strength
|
||||
inputs["strength"] = 0.75
|
||||
image = pipe(**inputs).images
|
||||
# verify that the returned image has the same height and width as the input height and width
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943])
|
||||
assert np.abs(expected_slice - image_slice).max() < 3e-3
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
@@ -427,24 +447,30 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
|
||||
mask = Image.fromarray((mask * 255).astype(np.uint8))
|
||||
|
||||
t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width)
|
||||
t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True)
|
||||
|
||||
self.assertTrue(isinstance(t_mask, torch.Tensor))
|
||||
self.assertTrue(isinstance(t_masked, torch.Tensor))
|
||||
self.assertTrue(isinstance(t_image, torch.Tensor))
|
||||
|
||||
self.assertEqual(t_mask.ndim, 4)
|
||||
self.assertEqual(t_masked.ndim, 4)
|
||||
self.assertEqual(t_image.ndim, 4)
|
||||
|
||||
self.assertEqual(t_mask.shape, (1, 1, height, width))
|
||||
self.assertEqual(t_masked.shape, (1, 3, height, width))
|
||||
self.assertEqual(t_image.shape, (1, 3, height, width))
|
||||
|
||||
self.assertTrue(t_mask.dtype == torch.float32)
|
||||
self.assertTrue(t_masked.dtype == torch.float32)
|
||||
self.assertTrue(t_image.dtype == torch.float32)
|
||||
|
||||
self.assertTrue(t_mask.min() >= 0.0)
|
||||
self.assertTrue(t_mask.max() <= 1.0)
|
||||
self.assertTrue(t_masked.min() >= -1.0)
|
||||
self.assertTrue(t_masked.min() <= 1.0)
|
||||
self.assertTrue(t_image.min() >= -1.0)
|
||||
self.assertTrue(t_image.min() >= -1.0)
|
||||
|
||||
self.assertTrue(t_mask.sum() > 0.0)
|
||||
|
||||
@@ -467,11 +493,16 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
)
|
||||
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
|
||||
|
||||
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
|
||||
t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(
|
||||
im_pil, mask_pil, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_np == t_mask_pil).all())
|
||||
self.assertTrue((t_masked_np == t_masked_pil).all())
|
||||
self.assertTrue((t_image_np == t_image_pil).all())
|
||||
|
||||
def test_torch_3D_2D_inputs(self):
|
||||
height, width = 32, 32
|
||||
@@ -501,13 +532,16 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
im_np = im_tensor.numpy().transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()
|
||||
|
||||
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_3D_3D_inputs(self):
|
||||
height, width = 32, 32
|
||||
@@ -538,13 +572,16 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
im_np = im_tensor.numpy().transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()[0]
|
||||
|
||||
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_4D_2D_inputs(self):
|
||||
height, width = 32, 32
|
||||
@@ -575,13 +612,16 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()
|
||||
|
||||
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_4D_3D_inputs(self):
|
||||
height, width = 32, 32
|
||||
@@ -613,13 +653,16 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()[0]
|
||||
|
||||
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_4D_4D_inputs(self):
|
||||
height, width = 32, 32
|
||||
@@ -652,13 +695,16 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()[0][0]
|
||||
|
||||
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_batch_4D_3D(self):
|
||||
height, width = 32, 32
|
||||
@@ -691,15 +737,17 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
|
||||
mask_nps = [mask.numpy() for mask in mask_tensor]
|
||||
|
||||
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
|
||||
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
|
||||
t_mask_np = torch.cat([n[0] for n in nps])
|
||||
t_masked_np = torch.cat([n[1] for n in nps])
|
||||
t_image_np = torch.cat([n[2] for n in nps])
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_batch_4D_4D(self):
|
||||
height, width = 32, 32
|
||||
@@ -733,15 +781,17 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
|
||||
mask_nps = [mask.numpy()[0] for mask in mask_tensor]
|
||||
|
||||
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
|
||||
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
|
||||
t_mask_np = torch.cat([n[0] for n in nps])
|
||||
t_masked_np = torch.cat([n[1] for n in nps])
|
||||
t_image_np = torch.cat([n[2] for n in nps])
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_shape_mismatch(self):
|
||||
height, width = 32, 32
|
||||
@@ -757,6 +807,7 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
torch.randn(64, 64),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test batch dim
|
||||
with self.assertRaises(AssertionError):
|
||||
@@ -770,6 +821,7 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
torch.randn(4, 64, 64),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test batch dim
|
||||
with self.assertRaises(AssertionError):
|
||||
@@ -783,6 +835,7 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
torch.randn(4, 1, 64, 64),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
def test_type_mismatch(self):
|
||||
@@ -803,6 +856,7 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
).numpy(),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test tensors-only
|
||||
with self.assertRaises(TypeError):
|
||||
@@ -819,6 +873,7 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
def test_channels_first(self):
|
||||
@@ -835,6 +890,7 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
def test_tensor_range(self):
|
||||
@@ -855,6 +911,7 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test im >= -1
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -871,6 +928,7 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test mask <= 1
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -887,6 +945,7 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
* 2,
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test mask >= 0
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -903,4 +962,5 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
|
||||
* -1,
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user