mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
92 lines
2.4 KiB
Python
92 lines
2.4 KiB
Python
import argparse
|
|
import os
|
|
|
|
import torch
|
|
from PIL import Image, ImageFilter
|
|
from transformers import CLIPTextModel
|
|
|
|
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
|
|
|
|
|
|
# Fixed settings of this inference script.
_BASE_MODEL = "stabilityai/stable-diffusion-2-inpainting"
_PROMPT = "a photo of sks"
_NUM_SAMPLES = 16
_DEVICE = "cuda"

# The parser is built at import time (cheap, no side effects); the command
# line itself is only parsed inside main().
parser = argparse.ArgumentParser(description="Inference")
parser.add_argument(
    "--model_path",
    type=str,
    required=True,
    help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--validation_image",
    type=str,
    required=True,
    help="Path to the validation image file.",
)
parser.add_argument(
    "--validation_mask",
    type=str,
    required=True,
    help="Path to the validation mask image file.",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="./test-infer/",
    help="The output directory where predictions are saved",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible inference.")


def _load_inputs(image_path, mask_path):
    """Load the validation image as RGB and the mask as single-channel "L".

    Both files are fully decoded and closed before returning. Normalizing the
    modes up front keeps the final ``Image.composite`` call valid: it needs
    both images in the same mode and a mask in mode "1", "L" or "RGBA".

    Raises:
        ValueError: if the image and the mask differ in size.
    """
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    with Image.open(mask_path) as raw_mask:
        mask_image = raw_mask.convert("L")

    if image.size != mask_image.size:
        raise ValueError(
            f"Validation image size {image.size} does not match mask size {mask_image.size}."
        )
    return image, mask_image


def _load_pipeline(model_path):
    """Build the inpainting pipeline and swap in the weights from *model_path*.

    The base pipeline is loaded from ``_BASE_MODEL``; its UNet and text encoder
    are replaced by the ones stored under the ``unet`` and ``text_encoder``
    subfolders of *model_path*, and the scheduler is replaced by a
    ``DPMSolverMultistepScheduler`` built from the existing scheduler config.
    """
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        _BASE_MODEL, torch_dtype=torch.float32, revision=None
    )
    pipe.unet = UNet2DConditionModel.from_pretrained(
        model_path,
        subfolder="unet",
        revision=None,
    )
    pipe.text_encoder = CLIPTextModel.from_pretrained(
        model_path,
        subfolder="text_encoder",
        revision=None,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe.to(_DEVICE)


def _feather_mask(mask_image):
    """Return *mask_image* after a 3x3 max filter followed by a radius-1 box blur.

    The max filter takes the largest value in each 3x3 window, so the bright
    area of the mask grows slightly; the blur then softens its edge so the
    composite has no hard seam.
    """
    grown = mask_image.filter(ImageFilter.MaxFilter(3))
    return grown.filter(ImageFilter.BoxBlur(1))


def main(argv=None):
    """Run inpainting on one image/mask pair and save the blended results.

    Args:
        argv: Optional list of command-line arguments; ``None`` means
            ``sys.argv[1:]`` (standard ``argparse`` behavior).

    Writes ``0.png`` .. ``N-1.png`` into ``--output_dir``, one file per
    generated sample.
    """
    args = parser.parse_args(argv)
    os.makedirs(args.output_dir, exist_ok=True)

    # Read the inputs first so a bad path or mismatched mask fails before the
    # (slow) model load and sampling.
    image, mask_image = _load_inputs(args.validation_image, args.validation_mask)

    # create & load model
    pipe = _load_pipeline(args.model_path)

    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=_DEVICE).manual_seed(args.seed)

    results = pipe(
        [_PROMPT] * _NUM_SAMPLES,
        image=image,
        mask_image=mask_image,
        num_inference_steps=25,
        guidance_scale=5,
        generator=generator,
    ).images

    blend_mask = _feather_mask(mask_image)

    for idx, result in enumerate(results):
        # NOTE(review): the pipeline may return images at its own working
        # resolution rather than the input's -- confirm. Image.composite needs
        # identical sizes, so bring the sample back to the input size.
        if result.size != image.size:
            result = result.resize(image.size)
        blended = Image.composite(result, image, blend_mask)
        blended.save(os.path.join(args.output_dir, f"{idx}.png"))

    del pipe
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
|