From 03cd62520ffff950e5ca62cf2e8d7d3a0cd2f5bb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 3 Mar 2024 15:10:55 +0530 Subject: [PATCH] feat: add ip adapter benchmark (#6936) * feat: add ip adapter benchmark * sdxl support too. * Empty-Commit --- benchmarks/base_classes.py | 29 ++++++++++++++++++++++++++ benchmarks/benchmark_ip_adapters.py | 32 +++++++++++++++++++++++++++++ benchmarks/run_all.py | 2 +- 3 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 benchmarks/benchmark_ip_adapters.py diff --git a/benchmarks/base_classes.py b/benchmarks/base_classes.py index 8f99b26f3a..dc1ca72388 100644 --- a/benchmarks/base_classes.py +++ b/benchmarks/base_classes.py @@ -236,6 +236,35 @@ class InpaintingBenchmark(ImageToImageBenchmark): ) +class IPAdapterTextToImageBenchmark(TextToImageBenchmark): + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png" + image = load_image(url) + + def __init__(self, args): + pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda") + pipe.load_ip_adapter( + args.ip_adapter_id[0], + subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models", + weight_name=args.ip_adapter_id[1], + ) + + if args.run_compile: + pipe.unet.to(memory_format=torch.channels_last) + print("Run torch compile") + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + pipe.set_progress_bar_config(disable=True) + self.pipe = pipe + + def run_inference(self, pipe, args): + _ = pipe( + prompt=PROMPT, + ip_adapter_image=self.image, + num_inference_steps=args.num_inference_steps, + num_images_per_prompt=args.batch_size, + ) + + class ControlNetBenchmark(TextToImageBenchmark): pipeline_class = StableDiffusionControlNetPipeline aux_network_class = ControlNetModel diff --git a/benchmarks/benchmark_ip_adapters.py b/benchmarks/benchmark_ip_adapters.py new file mode 100644 index 0000000000..5c11ab3838 --- /dev/null +++ b/benchmarks/benchmark_ip_adapters.py @@ -0,0 +1,32 @@ +import argparse +import sys + + +sys.path.append(".") +from base_classes import IPAdapterTextToImageBenchmark # noqa: E402 + + +IP_ADAPTER_CKPTS = { + "runwayml/stable-diffusion-v1-5": ("h94/IP-Adapter", "ip-adapter_sd15.bin"), + "stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"), +} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt", + type=str, + default="runwayml/stable-diffusion-v1-5", + choices=list(IP_ADAPTER_CKPTS.keys()), + ) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--num_inference_steps", type=int, default=50) + parser.add_argument("--model_cpu_offload", action="store_true") + parser.add_argument("--run_compile", action="store_true") + args = parser.parse_args() + + args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt] + benchmark_pipe = IPAdapterTextToImageBenchmark(args) + args.ckpt = f"{args.ckpt} (IP-Adapter)" + benchmark_pipe.benchmark(args) diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index c70fb22273..8750e1333d 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -72,7 +72,7 @@ def main(): command += " --run_compile" run_command(command.split()) - elif file == "benchmark_sd_inpainting.py": + elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]: sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0" command = f"python {file} --ckpt {sdxl_ckpt}" run_command(command.split())