
[Breaking change] fix legacy inpaint noise and resize mask tensor (#2147)

* fix legacy inpaint noise and resize mask tensor

* updated legacy inpaint pipe test expected_slice
Authored by 1lint on 2023-01-31 03:44:35 -08:00, committed by GitHub
parent 7d96b38b70
commit d1efefe15e
2 changed files with 37 additions and 17 deletions

File: pipeline_stable_diffusion_inpaint_legacy.py

@@ -45,16 +45,34 @@ def preprocess_image(image):
 def preprocess_mask(mask, scale_factor=8):
-    mask = mask.convert("L")
-    w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-    mask = 1 - mask  # repaint white, keep black
-    mask = torch.from_numpy(mask)
-    return mask
+    if not isinstance(mask, torch.FloatTensor):
+        mask = mask.convert("L")
+        w, h = mask.size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+        mask = np.array(mask).astype(np.float32) / 255.0
+        mask = np.tile(mask, (4, 1, 1))
+        mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
+        mask = 1 - mask  # repaint white, keep black
+        mask = torch.from_numpy(mask)
+        return mask
+    else:
+        valid_mask_channel_sizes = [1, 3]
+        # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
+        if mask.shape[3] in valid_mask_channel_sizes:
+            mask = mask.permute(0, 3, 1, 2)
+        elif mask.shape[1] not in valid_mask_channel_sizes:
+            raise ValueError(
+                f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension, but received mask of shape {tuple(mask.shape)}"
+            )
+        # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
+        mask = mask.mean(dim=1, keepdim=True)
+        h, w = mask.shape[-2:]
+        h, w = map(lambda x: x - x % 32, (h, w))  # resize to integer multiple of 32
+        mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
+        return mask


 class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
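
For readers outside the diff, a minimal sketch of what the new tensor branch does to a mask. The shapes, variable names, and the scale factor of 8 below are illustrative assumptions, not part of the commit:

    import torch

    # Hypothetical channels-last mask: batch of 1, 64x64, 3 channels.
    mask = torch.ones(1, 64, 64, 3)

    # Mirror the tensor branch of preprocess_mask step by step.
    mask = mask.permute(0, 3, 1, 2)        # to PyTorch layout -> (1, 3, 64, 64)
    mask = mask.mean(dim=1, keepdim=True)  # collapse channels  -> (1, 1, 64, 64)
    h, w = mask.shape[-2:]
    h, w = map(lambda x: x - x % 32, (h, w))  # snap to a multiple of 32
    mask = torch.nn.functional.interpolate(mask, (h // 8, w // 8))
    print(mask.shape)  # torch.Size([1, 1, 8, 8]), i.e. latent resolution

The single remaining mask channel later broadcasts over the four latent channels during blending, which is why reducing to one channel here is sufficient.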
@@ -497,8 +515,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                 replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
-                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
-                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+                PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the
+                expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                 is 1, the denoising process will be run on the masked area for the full number of iterations specified
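
To make the documented shape contract concrete, a hedged sketch of tensor masks the updated docstring admits (sizes and values below are made up for illustration):

    import torch

    # Channels-last (B, H, W, C) with C = 1: accepted, gets permuted.
    mask_hwc = torch.zeros(1, 512, 512, 1)

    # Channels-first (B, C, H, W) with C = 3: accepted, averaged down to 1 channel.
    mask_chw = torch.zeros(1, 3, 512, 512)

    # Channel size outside {1, 3} in both positions: would trigger the
    # ValueError added in this commit if passed to preprocess_mask.
    bad_mask = torch.zeros(1, 4, 512, 512)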
@@ -585,8 +603,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         if not isinstance(image, torch.FloatTensor):
             image = preprocess_image(image)

-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        mask_image = preprocess_mask(mask_image, self.vae_scale_factor)

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
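
(The removed isinstance guard is the heart of the "resize mask tensor" part of the fix: previously a mask passed as a tensor skipped preprocess_mask entirely and was never resized to latent resolution; now both PIL and tensor masks go through it.)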
@@ -640,6 +657,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)

+        # use original latents corresponding to unmasked portions of the image
+        latents = (init_latents_orig * mask) + (latents * (1 - mask))
+
         # 10. Post-processing
         image = self.decode_latents(latents)
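
A tiny numeric sketch of the blend this hunk adds after the denoising loop. The 1x1 "latents" below are fabricated purely to show the convention (mask value 1 keeps the original latent, 0 keeps the denoised one):

    import torch

    init_latents_orig = torch.tensor([[4.0]])  # latent encoded from the input image
    latents = torch.tensor([[9.0]])            # latent produced by the denoising loop

    keep = torch.tensor([[1.0]])               # unmasked region: original survives
    print((init_latents_orig * keep) + (latents * (1 - keep)))          # tensor([[4.]])

    repaint = torch.tensor([[0.0]])            # masked region: new content survives
    print((init_latents_orig * repaint) + (latents * (1 - repaint)))    # tensor([[9.]])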

File: test_stable_diffusion_inpaint_legacy.py

@@ -212,8 +212,8 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
         image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
+        expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -260,7 +260,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150])
+        expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2