Mirror of https://github.com/huggingface/diffusers.git
Fix inpainting script (#258)
* expand latents before the check, style
* update readme
```diff
@@ -11,7 +11,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
-def preprocess(image):
+def preprocess_image(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
```
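The rename from `preprocess` to `preprocess_image` disambiguates it from `preprocess_mask` below. The hunk only shows the head of the function; here is a sketch of the full shape and value flow for a 512x512 RGB input. The NHWC-to-NCHW transpose line is an assumption (it is not shown in this hunk), while the `torch.from_numpy` and `2.0 * image - 1.0` lines appear as context in the next hunk:

```python
# A sketch of preprocess_image for a 512x512 RGB input; the transpose line
# is assumed, not shown in this hunk.
import numpy as np
import PIL.Image
import torch

image = PIL.Image.new("RGB", (512, 512))               # stand-in for a real photo
w, h = map(lambda x: x - x % 32, image.size)           # round down to multiples of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
arr = np.array(image).astype(np.float32) / 255.0       # (512, 512, 3), values in [0, 1]
arr = arr[None].transpose(0, 3, 1, 2)                  # (1, 3, 512, 512), NCHW with batch dim
tensor = torch.from_numpy(arr)
tensor = 2.0 * tensor - 1.0                            # [-1, 1], the range the VAE expects
print(tensor.shape)  # torch.Size([1, 3, 512, 512])
```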
```diff
@@ -20,15 +20,16 @@ def preprocess(image):
     image = torch.from_numpy(image)
     return 2.0 * image - 1.0
 
 
 def preprocess_mask(mask):
-    mask=mask.convert("L")
+    mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w//8, h//8), resample=PIL.Image.NEAREST)
+    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
     mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask,(4,1,1))
-    mask = mask[None].transpose(0, 1, 2, 3)#what does this step do?
-    mask = 1 - mask #repaint white, keep black
+    mask = np.tile(mask, (4, 1, 1))
+    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
+    mask = 1 - mask  # repaint white, keep black
     mask = torch.from_numpy(mask)
     return mask
```
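The style pass keeps the `# what does this step do?` comment; for the record, `mask[None]` adds a batch dimension and `transpose(0, 1, 2, 3)` is an identity permutation, i.e. a no-op. A sketch of the shape flow for a 512x512 mask, with a solid white test mask standing in for a real one:

```python
# Shape flow through preprocess_mask for a 512x512 mask (white = repaint).
import numpy as np
import PIL.Image
import torch

mask = PIL.Image.new("L", (512, 512), color=255)                  # stand-in mask
w, h = map(lambda x: x - x % 32, mask.size)                       # (512, 512)
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)  # 64x64, the latent resolution
arr = np.array(mask).astype(np.float32) / 255.0                   # (64, 64), values in [0, 1]
arr = np.tile(arr, (4, 1, 1))                                     # (4, 64, 64), one copy per latent channel
arr = arr[None].transpose(0, 1, 2, 3)                             # (1, 4, 64, 64); [None] adds the batch
                                                                  # dim, the transpose is an identity no-op
arr = 1 - arr                                                     # white (1) -> 0: repaint; black (0) -> 1: keep
mask_tensor = torch.from_numpy(arr)
print(mask_tensor.shape)  # torch.Size([1, 4, 64, 64])
```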
```diff
@@ -90,25 +91,25 @@ class StableDiffusionInpaintingPipeline(DiffusionPipeline):
 
         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
 
-        #preprocess image
-        init_image = preprocess(init_image).to(self.device)
+        # preprocess image
+        init_image = preprocess_image(init_image).to(self.device)
 
         # encode the init image into latents and scale the latents
         init_latents = self.vae.encode(init_image).sample()
         init_latents = 0.18215 * init_latents
+
+        # prepare init_latents noise to latents
+        init_latents = torch.cat([init_latents] * batch_size)
         init_latents_orig = init_latents
 
         # preprocess mask
         mask = preprocess_mask(mask_image).to(self.device)
         mask = torch.cat([mask] * batch_size)
 
-        #check sizes
+        # check sizes
         if not mask.shape == init_latents.shape:
             raise ValueError(f"The mask and init_image should be the same size!")
 
-        # prepare init_latents noise to latents
-        init_latents = torch.cat([init_latents] * batch_size)
-
         # get the original timestep using init_timestep
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
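```

This hunk is the bug fix named in the commit message. The VAE encodes a single image, so `init_latents` starts with batch dimension 1, while the mask is tiled to `batch_size` before the shape check; checking before expanding therefore raised a spurious error whenever `batch_size > 1`. A shape-only sketch of the failure mode, with dummy tensors for a 512x512 image:

```python
# Why the old order raised for batch_size > 1: a shape-only sketch.
import torch

batch_size = 2
init_latents = torch.zeros(1, 4, 64, 64)       # VAE output for a single 512x512 image
mask = torch.zeros(1, 4, 64, 64)
mask = torch.cat([mask] * batch_size)          # (2, 4, 64, 64)

# Old order: check first -> (1, 4, 64, 64) != (2, 4, 64, 64), spurious ValueError.
assert mask.shape != init_latents.shape

# New order: expand first, then check -> shapes agree.
init_latents = torch.cat([init_latents] * batch_size)  # (2, 4, 64, 64)
assert mask.shape == init_latents.shape
```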
```diff
@@ -172,9 +173,9 @@ class StableDiffusionInpaintingPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
 
-            #masking
+            # masking
             init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
-            latents = ( init_latents_proper * mask ) + ( latents * (1-mask) )
+            latents = (init_latents_proper * mask) + (latents * (1 - mask))
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
```
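The reformatted line is the heart of the inpainting loop: at every denoising step the known region is overwritten with the original latents re-noised to the current timestep, and only the masked region keeps the model's prediction. In isolation, with stand-in tensors rather than the pipeline's state:

```python
# The per-step blend: mask == 1 keeps the original image content,
# mask == 0 lets the denoised prediction through.
import torch

latents = torch.randn(1, 4, 64, 64)              # current denoised latents
init_latents_proper = torch.randn(1, 4, 64, 64)  # original latents + noise for step t (stand-in)
mask = torch.ones(1, 4, 64, 64)                  # 1 = keep, 0 = repaint

latents = (init_latents_proper * mask) + (latents * (1 - mask))
```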
````diff
@@ -52,3 +52,46 @@
-You can also run this example on colab [...]
+... shows how to do it step by step. You can also run it in Google Colab [](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
+
+## In-painting using Stable Diffusion
+
+The `inpainting.py` script implements `StableDiffusionInpaintingPipeline`. This script lets you edit specific parts of an image by providing a mask and a text prompt.
+
+### How to use it
+
+```python
+from io import BytesIO
+
+import torch
+from torch import autocast
+import requests
+import PIL.Image
+
+from inpainting import StableDiffusionInpaintingPipeline
+
+
+def download_image(url):
+    response = requests.get(url)
+    return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+
+device = "cuda"
+pipe = StableDiffusionInpaintingPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4",
+    revision="fp16",
+    torch_dtype=torch.float16,
+    use_auth_token=True,
+).to(device)
+
+prompt = "a cat sitting on a bench"
+with autocast("cuda"):
+    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
+
+images[0].save("cat_on_bench.png")
+```
+
+You can also run this example on colab [](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb)
````