1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix inpainting script (#258)

* expand latents before the check, style

* update readme
This commit is contained in:
Suraj Patil
2022-08-26 21:16:43 +05:30
committed by GitHub
parent 11133dcca1
commit 5cbed8e0d1
2 changed files with 58 additions and 14 deletions

View File

@@ -11,7 +11,7 @@ from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
def preprocess(image):
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
@@ -20,15 +20,16 @@ def preprocess(image):
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def preprocess_mask(mask):
mask=mask.convert("L")
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w//8, h//8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask,(4,1,1))
mask = mask[None].transpose(0, 1, 2, 3)#what does this step do?
mask = 1 - mask #repaint white, keep black
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
mask = 1 - mask # repaint white, keep black
mask = torch.from_numpy(mask)
return mask
@@ -90,25 +91,25 @@ class StableDiffusionInpaintingPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
#preprocess image
init_image = preprocess(init_image).to(self.device)
# preprocess image
init_image = preprocess_image(init_image).to(self.device)
# encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image).sample()
init_latents = 0.18215 * init_latents
# prepare init_latents noise to latents
init_latents = torch.cat([init_latents] * batch_size)
init_latents_orig = init_latents
# preprocess mask
mask = preprocess_mask(mask_image).to(self.device)
mask = torch.cat([mask] * batch_size)
#check sizes
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError(f"The mask and init_image should be the same size!")
# prepare init_latents noise to latents
init_latents = torch.cat([init_latents] * batch_size)
# get the original timestep using init_timestep
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
@@ -172,9 +173,9 @@ class StableDiffusionInpaintingPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
#masking
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
latents = ( init_latents_proper * mask ) + ( latents * (1-mask) )
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents

View File

@@ -52,3 +52,46 @@ You can also run this example on colab [![Open In Colab](https://colab.research.
## Tweak prompts reusing seeds and latents
You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
## In-painting using Stable Diffusion
The `inpainting.py` script implements `StableDiffusionInpaintingPipeline`. This script lets you edit specific parts of an image by providing a mask and text prompt.
### How to use it
```python
from io import BytesIO
from torch import autocast
import requests
import PIL
from inpainting import StableDiffusionInpaintingPipeline
def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
device = "cuda"
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
use_auth_token=True
).to(device)
prompt = "a cat sitting on a bench"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
images[0].save("cat_on_bench.png")
```
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb)