1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

ONNX supervised inpainting (#906)

* ONNX supervised inpainting

* sync with the torch pipeline

* fix concat

* update ref values

* back to 8 steps

* type fix

* make fix-copies
This commit is contained in:
Anton Lozhkov
2022-10-19 17:03:31 +02:00
committed by GitHub
parent 46557121e6
commit 89d124945a
4 changed files with 111 additions and 90 deletions

View File

@@ -99,7 +99,12 @@ def convert_models(model_path: str, output_path: str, opset: int):
unet_path = output_path / "unet" / "model.onnx"
onnx_export(
pipeline.unet,
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
model_args=(
torch.randn(2, pipeline.unet.in_channels, 64, 64),
torch.LongTensor([0, 1]),
torch.randn(2, 77, 768),
False,
),
output_path=unet_path,
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
output_names=["out_sample"], # has to be different from "sample" for correct tracing

View File

@@ -5,7 +5,6 @@ import numpy as np
import torch
import PIL
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -16,28 +15,29 @@ from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
logger = logging.get_logger(__name__)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
NUM_UNET_INPUT_CHANNELS = 9
NUM_LATENT_CHANNELS = 4
def prepare_mask_and_masked_image(image, mask, latents_shape):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
image = image.astype(np.float32) / 127.5 - 1.0
image_mask = np.array(mask.convert("L"))
masked_image = image * (image_mask < 127.5)
def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
mask = 1 - mask # repaint white, keep black
return mask
mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
return mask, masked_image
class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
@@ -129,14 +129,16 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
init_image: Union[np.ndarray, PIL.Image.Image],
mask_image: Union[np.ndarray, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
image: PIL.Image.Image,
mask_image: PIL.Image.Image,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
eta: float = 0.0,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -149,22 +151,21 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
mask_image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified
in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
be masked out with `mask_image` and repainted according to `prompt`.
mask_image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -179,6 +180,10 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -206,8 +211,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -285,41 +290,46 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
num_channels_latents = NUM_LATENT_CHANNELS
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
latents = np.random.randn(*latents_shape).astype(latents_dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# encode the init image into latents and scale the latents
init_latents = self.vae_encoder(sample=init_image)[0]
init_latents = 0.18215 * init_latents
# prepare mask and masked_image
mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:])
mask = mask.astype(latents.dtype)
masked_image = masked_image.astype(latents.dtype)
# Expand init_latents for batch_size and num_images_per_prompt
init_latents = np.concatenate([init_latents] * batch_size * num_images_per_prompt, axis=0)
init_latents_orig = init_latents
masked_image_latents = self.vae_encoder(sample=masked_image)[0]
masked_image_latents = 0.18215 * masked_image_latents
# preprocess mask
if not isinstance(mask_image, np.ndarray):
mask_image = preprocess_mask(mask_image)
mask = np.concatenate([mask_image] * batch_size * num_images_per_prompt)
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
# add noise to latents using the timesteps
noise = np.random.randn(*init_latents.shape).astype(np.float32)
init_latents = self.scheduler.add_noise(
torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)
init_latents = init_latents.numpy()
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
unet_input_channels = NUM_UNET_INPUT_CHANNELS
if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels:
raise ValueError(
"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -330,15 +340,13 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
if accepts_eta:
extra_step_kwargs["eta"] = eta
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].numpy()
for i, t in tqdm(enumerate(timesteps)):
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# concat latents, mask, masked_image_latnets in the channel dimension
latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.numpy()
# predict the noise residual
noise_pred = self.unet(
@@ -353,12 +361,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = latents.numpy()
# masking
init_latents_proper = self.scheduler.add_noise(
torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
)
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if callback is not None and i % callback_steps == 0:

View File

@@ -49,6 +49,21 @@ class StableDiffusionInpaintPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -2271,7 +2271,7 @@ class PipelineTesterMixin(unittest.TestCase):
)
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
"runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider"
)
pipe.set_progress_bar_config(disable=None)
@@ -2280,9 +2280,8 @@ class PipelineTesterMixin(unittest.TestCase):
np.random.seed(0)
output = pipe(
prompt=prompt,
init_image=init_image,
image=init_image,
mask_image=mask_image,
strength=0.75,
guidance_scale=7.5,
num_inference_steps=8,
output_type="np",
@@ -2291,7 +2290,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = images[0, 255:258, 255:258, -1]
assert images.shape == (1, 512, 512, 3)
expected_slice = np.array([0.3524, 0.3289, 0.3464, 0.3872, 0.4129, 0.3566, 0.3709, 0.4128, 0.3734])
expected_slice = np.array([0.2951, 0.2955, 0.2922, 0.2036, 0.1977, 0.2279, 0.1716, 0.1641, 0.1799])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@slow