From ea4cf2592865d3d5ea62f8ed5bb5ec04a54abfcf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 25 Jun 2023 21:39:15 +0000 Subject: [PATCH] clean up --- run_local_xl.py | 56 ------------------- .../pipeline_stable_diffusion_xl.py | 32 +++++++---- 2 files changed, 20 insertions(+), 68 deletions(-) delete mode 100755 run_local_xl.py diff --git a/run_local_xl.py b/run_local_xl.py deleted file mode 100755 index db41d2cdf9..0000000000 --- a/run_local_xl.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python3 -from diffusers import DiffusionPipeline, EulerDiscreteScheduler, StableDiffusionPipeline, KDPM2DiscreteScheduler, StableDiffusionImg2ImgPipeline, HeunDiscreteScheduler, KDPM2AncestralDiscreteScheduler, DDIMScheduler -import time -import os -from huggingface_hub import HfApi -# from compel import Compel -import torch -import sys -from pathlib import Path -import requests -from PIL import Image -from io import BytesIO - -path = sys.argv[1] - -api = HfApi() -start_time = time.time() -pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16) -pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) -# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) -# pipe = StableDiffusionImg2ImgPipeline.from_pretrained(path, torch_dtype=torch.float16, safety_checker=None - -# compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder) - - -pipe = pipe.to("cuda") - -prompt = "An astronaut riding a green horse on Mars" - -# rompts = ["a cat playing with a ball++ in the forest", "a cat playing with a ball++ in the forest", "a cat playing with a ball-- in the forest"] - -# prompt_embeds = torch.cat([compel.build_conditioning_tensor(prompt) for prompt in prompts]) - -# generator = [torch.Generator(device="cuda").manual_seed(0) for _ in range(prompt_embeds.shape[0])] -# -# url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" -# -# response = requests.get(url) -# image = Image.open(BytesIO(response.content)).convert("RGB") -# image.thumbnail((768, 768)) -# - -# pipe.unet.set_default_attn_processor() -image = pipe(prompt=prompt).images[0] - -file_name = f"aaa" -path = os.path.join(Path.home(), "images", f"{file_name}.png") -image.save(path) - -api.upload_file( - path_or_fileobj=path, - path_in_repo=path.split("/")[-1], - repo_id="patrickvonplaten/images", - repo_type="dataset", -) -print(f"https://huggingface.co/datasets/patrickvonplaten/images/blob/main/{file_name}.png") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 14e674ca86..0610936b0a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -14,8 +14,7 @@ import inspect import warnings -from typing import Any, Callable, Dict, List, Optional, Union -from pytorch_lightning import seed_everything +from typing import Any, Callable, Dict, List, Optional, Union, Tuple import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection @@ -511,10 +510,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): ) if latents is None: - seed_everything(0) - # latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = randn_tensor(shape, generator=generator, device="cpu", dtype=torch.float32) - latents = latents.to(dtype=dtype, device=device) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) @@ -544,6 +540,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline): callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = (1024, 1024), + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = (1024, 1024), ): r""" Function invoked when calling the pipeline for generation. @@ -609,6 +608,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline): Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO Examples: @@ -681,15 +686,18 @@ class StableDiffusionXLPipeline(DiffusionPipeline): # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + add_text_embeds = pooled_prompt_embeds + add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.long) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - - # TODO - find better explanations where they come from - # original_size_as_tuple x crops_coords_top_left x target_size_as_tuple - add_time_ids = torch.tensor(2 * [[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long, device=add_text_embeds.device) + add_time_ids = add_time_ids.to(device) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps):