diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 67ce572ea4..518a9a3e97 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -120,7 +120,7 @@ def prepare_mask_and_masked_image(image, mask, height, width): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width - image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): @@ -134,7 +134,7 @@ def prepare_mask_and_masked_image(image, mask, height, width): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): - mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 497d9e5367..93c3f7ec20 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -304,23 +304,23 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase): assert np.abs(expected_slice - image_slice).max() < 1e-3 def test_stable_diffusion_inpaint_pil_input_resolution_test(self): - pipe = StableDiffusionInpaintPipeline.from_pretrained( - "runwayml/stable-diffusion-inpainting", safety_checker=None - ) - pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() - inputs = self.get_inputs(torch_device) - # change input image to a random size (one that would cause a tensor mismatch error) - inputs['image'] = inputs['image'].resize((127,127)) - inputs['mask_image'] = inputs['mask_image'].resize((127,127)) - inputs['height'] = 128 - inputs['width'] = 128 - image = pipe(**inputs).images - # verify that the returned image has the same height and width as the input height and width - assert image.shape == (1, inputs['height'], inputs['width'], 3) + inputs = self.get_inputs(torch_device) + # change input image to a random size (one that would cause a tensor mismatch error) + inputs["image"] = inputs["image"].resize((127, 127)) + inputs["mask_image"] = inputs["mask_image"].resize((127, 127)) + inputs["height"] = 128 + inputs["width"] = 128 + image = pipe(**inputs).images + # verify that the returned image has the same height and width as the input height and width + assert image.shape == (1, inputs["height"], inputs["width"], 3) @nightly @@ -451,7 +451,18 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) im_pil = Image.fromarray(im_np) - mask_np = np.random.randint(0, 255, (height, width,), dtype=np.uint8) > 127.5 + mask_np = ( + np.random.randint( + 0, + 255, + ( + height, + width, + ), + dtype=np.uint8, + ) + > 127.5 + ) mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) @@ -463,12 +474,34 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) def test_torch_3D_2D_inputs(self): height, width = 32, 32 - im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5 + im_tensor = torch.randint( + 0, + 255, + ( + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width + ) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) @@ -477,12 +510,35 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) def test_torch_3D_3D_inputs(self): height, width = 32, 32 - im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5 + im_tensor = torch.randint( + 0, + 255, + ( + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 1, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width + ) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) @@ -491,12 +547,35 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) def test_torch_4D_2D_inputs(self): height, width = 32, 32 - im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5 + im_tensor = torch.randint( + 0, + 255, + ( + 1, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width + ) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) @@ -505,12 +584,36 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) def test_torch_4D_3D_inputs(self): height, width = 32, 32 - im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5 + im_tensor = torch.randint( + 0, + 255, + ( + 1, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 1, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width + ) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) @@ -519,12 +622,37 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) def test_torch_4D_4D_inputs(self): height, width = 32, 32 - im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 1, height, width,), dtype=torch.uint8) > 127.5 + im_tensor = torch.randint( + 0, + 255, + ( + 1, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 1, + 1, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0][0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width + ) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) @@ -533,13 +661,37 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) def test_torch_batch_4D_3D(self): height, width = 32, 32 - im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (2, height, width,), dtype=torch.uint8) > 127.5 + im_tensor = torch.randint( + 0, + 255, + ( + 2, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 2, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy() for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width + ) nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) @@ -550,13 +702,38 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) def test_torch_batch_4D_4D(self): height, width = 32, 32 - im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (2, 1, height, width,), dtype=torch.uint8) > 127.5 + im_tensor = torch.randint( + 0, + 255, + ( + 2, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 2, + 1, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy()[0] for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width + ) nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) @@ -569,43 +746,159 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) # test height and width with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(3, height, width,), torch.randn(64, 64), height, width) + prepare_mask_and_masked_image( + torch.randn( + 3, + height, + width, + ), + torch.randn(64, 64), + height, + width, + ) # test batch dim with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 64, 64), height, width) + prepare_mask_and_masked_image( + torch.randn( + 2, + 3, + height, + width, + ), + torch.randn(4, 64, 64), + height, + width, + ) # test batch dim with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 1, 64, 64), height, width) + prepare_mask_and_masked_image( + torch.randn( + 2, + 3, + height, + width, + ), + torch.randn(4, 1, 64, 64), + height, + width, + ) def test_type_mismatch(self): height, width = 32, 32 # test tensors-only with self.assertRaises(TypeError): - prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.rand(3, height, width,).numpy(), height, width) + prepare_mask_and_masked_image( + torch.rand( + 3, + height, + width, + ), + torch.rand( + 3, + height, + width, + ).numpy(), + height, + width, + ) # test tensors-only with self.assertRaises(TypeError): - prepare_mask_and_masked_image(torch.rand(3, height, width,).numpy(), torch.rand(3, height, width,), height, width) + prepare_mask_and_masked_image( + torch.rand( + 3, + height, + width, + ).numpy(), + torch.rand( + 3, + height, + width, + ), + height, + width, + ) def test_channels_first(self): height, width = 32, 32 # test channels first for 3D tensors with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.rand(height, width, 3), torch.rand(3, height, width,), height, width) + prepare_mask_and_masked_image( + torch.rand(height, width, 3), + torch.rand( + 3, + height, + width, + ), + height, + width, + ) def test_tensor_range(self): height, width = 32, 32 # test im <= 1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.ones(3, height, width,) * 2, torch.rand(height, width,), height, width) + prepare_mask_and_masked_image( + torch.ones( + 3, + height, + width, + ) + * 2, + torch.rand( + height, + width, + ), + height, + width, + ) # test im >= -1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.ones(3, height, width,) * (-2), torch.rand(height, width,), height, width) + prepare_mask_and_masked_image( + torch.ones( + 3, + height, + width, + ) + * (-2), + torch.rand( + height, + width, + ), + height, + width, + ) # test mask <= 1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * 2, height, width) + prepare_mask_and_masked_image( + torch.rand( + 3, + height, + width, + ), + torch.ones( + height, + width, + ) + * 2, + height, + width, + ) # test mask >= 0 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * -1, height, width) + prepare_mask_and_masked_image( + torch.rand( + 3, + height, + width, + ), + torch.ones( + height, + width, + ) + * -1, + height, + width, + )