From 57b8406ef04d264a70e07633853953e16861922a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 23 Jun 2023 10:37:57 +0200 Subject: [PATCH 01/14] Add new text encoder --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 1 + .../pipelines/stable_diffusion_xl/__init__.py | 30 + .../pipeline_stable_diffusion_xl.py | 749 ++++++++++++++++++ 4 files changed, 781 insertions(+) create mode 100644 src/diffusers/pipelines/stable_diffusion_xl/__init__.py create mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0290707534..0d588a85f5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -160,6 +160,7 @@ else: StableDiffusionPix2PixZeroPipeline, StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, + StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, TextToVideoSDPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b165024084..4ac734f416 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -89,6 +89,7 @@ else: StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .stable_diffusion_xl import StableDiffusionXLPipeline from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 0000000000..807673c3f5 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL + +from ...utils import BaseOutput, is_torch_available, is_transformers_available + + +@dataclass +# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with StableDiffusion->StableDiffusionXL +class StableDiffusionXLPipelineOutput(BaseOutput): + """ + Output class for Alt Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +if is_transformers_available() and is_torch_available(): + from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py new file mode 100644 index 0000000000..c8d590db47 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -0,0 +1,749 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionXLPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + # vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModel, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + # unet: UNet2DConditionModel, + # scheduler: KarrasDiffusionSchedulers, + # safety_checker: StableDiffusionSafetyChecker, + # feature_extractor: CLIPImageProcessor, + # requires_safety_checker: bool = True, + ): + super().__init__() + + # if safety_checker is None and requires_safety_checker: + # logger.warning( + # f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + # " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + # " results in services or applications open to the public. Both the diffusers team and Hugging Face" + # " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + # " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + # " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + # ) + + # if safety_checker is not None and feature_extractor is None: + # raise ValueError( + # "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + # " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + # ) + + self.register_modules( + # vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + # unet=unet, + # scheduler=scheduler, + # safety_checker=safety_checker, + # feature_extractor=feature_extractor, + ) + # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + return "cpu" + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds.pooler_output + + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds.pooler_output + + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + # height = height or self.unet.config.sample_size * self.vae_scale_factor + # width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + # self.check_inputs( + # prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + # ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # device = self._execution_device + device = "cpu" + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + import ipdb; ipdb.set_trace() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionXLPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 39b0b97aacfff9221504c90eb2a3987dac606559 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 23 Jun 2023 12:16:51 +0200 Subject: [PATCH 02/14] add transformers depth --- src/diffusers/models/unet_2d_blocks.py | 13 ++++++++--- src/diffusers/models/unet_2d_condition.py | 10 +++++++++ .../stable_diffusion/convert_from_ckpt.py | 22 ++++++++++++++----- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index eee7e6023e..a57a469caa 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -38,6 +38,7 @@ def get_down_block( add_downsample, resnet_eps, resnet_act_fn, + num_transformer_blocks=1, num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, @@ -106,6 +107,7 @@ def get_down_block( raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") return CrossAttnDownBlock2D( num_layers=num_layers, + num_transformer_blocks=num_transformer_blocks, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -227,6 +229,7 @@ def get_up_block( add_upsample, resnet_eps, resnet_act_fn, + num_transformer_blocks=1, num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, @@ -281,6 +284,7 @@ def get_up_block( raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnUpBlock2D( num_layers=num_layers, + num_transformer_blocks=num_transformer_blocks, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, @@ -506,6 +510,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + num_transformer_blocks: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -548,7 +553,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, - num_layers=1, + num_layers=num_transformer_blocks, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -829,6 +834,7 @@ class CrossAttnDownBlock2D(nn.Module): temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + num_transformer_blocks: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -873,7 +879,7 @@ class CrossAttnDownBlock2D(nn.Module): num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=1, + num_layers=num_transformer_blocks, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -1939,6 +1945,7 @@ class CrossAttnUpBlock2D(nn.Module): temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + num_transformer_blocks: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -1984,7 +1991,7 @@ class CrossAttnUpBlock2D(nn.Module): num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=1, + num_layers=num_transformer_blocks, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 7bca5c336c..0cc9618d91 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -96,6 +96,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. + num_transformer_blocks (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. @@ -168,6 +170,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, + num_transformer_blocks: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, @@ -381,6 +384,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) + if isinstance(num_transformer_blocks, int): + num_transformer_blocks = [num_transformer_blocks] * len(down_block_types) + if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the @@ -399,6 +405,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], + num_transformer_blocks=num_transformer_blocks[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, @@ -424,6 +431,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) # mid if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( + num_transformer_blocks=num_transformer_blocks[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, @@ -465,6 +473,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_num_transformer_blocks = list(reversed(num_transformer_blocks)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] @@ -485,6 +494,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, + num_transformer_blocks=reversed_num_transformer_blocks[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 3b3724f0d0..94cc1b36d2 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -233,7 +233,10 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa if controlnet: unet_params = original_config.model.params.control_stage_config.params else: - unet_params = original_config.model.params.unet_config.params + if original_config.model.params.unet_config is not None: + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params vae_params = original_config.model.params.first_stage_config.params.ddconfig @@ -253,6 +256,11 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa up_block_types.append(block_type) resolution //= 2 + if unet_params.transformer_depth is not None: + num_transformer_blocks = unet_params.transformer_depth if isinstance(unet_params.transformer_depth, int) else list(unet_params.transformer_depth) + else: + num_transformer_blocks = 1 + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) head_dim = unet_params.num_heads if "num_heads" in unet_params else None @@ -262,7 +270,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: - head_dim = [5, 10, 20, 20] + head_dim = [5 * c for c in list(unet_params.channel_mult)] class_embed_type = None projection_class_embeddings_input_dim = None @@ -286,6 +294,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa "use_linear_projection": use_linear_projection, "class_embed_type": class_embed_type, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "num_transformer_blocks": num_transformer_blocks, } if controlnet: @@ -1172,9 +1181,9 @@ def download_from_original_stable_diffusion_ckpt( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema ) - num_train_timesteps = original_config.model.params.timesteps - beta_start = original_config.model.params.linear_start - beta_end = original_config.model.params.linear_end + num_train_timesteps = original_config.model.params.timesteps or 1000 + beta_start = original_config.model.params.linear_start or 0.02 + beta_end = original_config.model.params.linear_end or 0.085 scheduler = DDIMScheduler( beta_end=beta_end, @@ -1216,8 +1225,9 @@ def download_from_original_stable_diffusion_ckpt( converted_unet_checkpoint = convert_ldm_unet_checkpoint( checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema ) - unet.load_state_dict(converted_unet_checkpoint) + # Works! + import ipdb; ipdb.set_trace() # Convert the VAE model. vae_config = create_vae_diffusers_config(original_config, image_size=image_size) From 50df26c1411c7f5c3c7973815b367ea5f31071a7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 23 Jun 2023 17:58:21 +0200 Subject: [PATCH 03/14] More --- src/diffusers/models/unet_2d_condition.py | 25 +++++++++++++-- .../pipeline_stable_diffusion_xl.py | 32 ++++++++++++------- 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 0cc9618d91..5f53b55286 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -179,6 +179,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", @@ -352,6 +353,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") @@ -789,7 +794,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) - emb = emb + aug_emb elif self.config.addition_embed_type == "text_image": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: @@ -801,7 +805,24 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) - emb = emb + aug_emb + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c8d590db47..189d362663 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -15,9 +15,10 @@ import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union +from pytorch_lightning import seed_everything import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor @@ -109,11 +110,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline): self, # vae: AutoencoderKL, text_encoder: CLIPTextModel, - text_encoder_2: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, - # unet: UNet2DConditionModel, - # scheduler: KarrasDiffusionSchedulers, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, # safety_checker: StableDiffusionSafetyChecker, # feature_extractor: CLIPImageProcessor, # requires_safety_checker: bool = True, @@ -142,12 +143,13 @@ class StableDiffusionXLPipeline(DiffusionPipeline): text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, - # unet=unet, - # scheduler=scheduler, + unet=unet, + scheduler=scheduler, # safety_checker=safety_checker, # feature_extractor=feature_extractor, ) # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 8 # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -341,7 +343,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds.pooler_output + pooled_prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.hidden_states[-2] @@ -398,7 +400,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds.pooler_output + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] @@ -630,8 +632,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - # height = height or self.unet.config.sample_size * self.vae_scale_factor - # width = width or self.unet.config.sample_size * self.vae_scale_factor + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + seed_everything(0) # 1. Check inputs. Raise error if not correct # self.check_inputs( @@ -668,7 +672,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline): negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) - import ipdb; ipdb.set_trace() # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -692,6 +695,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline): # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance @@ -699,13 +706,16 @@ class StableDiffusionXLPipeline(DiffusionPipeline): latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] + # TODO(Patrick) - forward path matches # perform guidance if do_classifier_free_guidance: From 4309a2c63bc3d04bcd4f40fc69b011406933c3a8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 23 Jun 2023 22:27:53 +0200 Subject: [PATCH 04/14] Correct conversion script --- ..._original_stable_diffusion_to_diffusers.py | 8 +++ .../stable_diffusion/convert_from_ckpt.py | 53 ++++++++++++++++--- .../pipeline_stable_diffusion_xl.py | 8 +-- 3 files changed, 59 insertions(+), 10 deletions(-) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index de64095523..73ec20dc67 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -126,6 +126,13 @@ if __name__ == "__main__": "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." ) parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument( + "--vae_path", + type=str, + default=None, + required=False, + help="Set to a path, hub id to an already converted vae to not convert it again." + ) args = parser.parse_args() pipe = download_from_original_stable_diffusion_ckpt( @@ -144,6 +151,7 @@ if __name__ == "__main__": stable_unclip_prior=args.stable_unclip_prior, clip_stats_path=args.clip_stats_path, controlnet=args.controlnet, + vae_path=args.vae_path, ) if args.half: diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 94cc1b36d2..a16a32a8b8 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -277,7 +277,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa if "num_classes" in unet_params: if unet_params.num_classes == "sequential": - class_embed_type = "projection" + if unet_params.context_dim == 2048: + # SDXL + class_embed_type = None + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + addition_embed_type = None + addition_time_embed_dim = None assert "adm_in_channels" in unet_params projection_class_embeddings_input_dim = unet_params.adm_in_channels else: @@ -293,6 +301,8 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, "num_transformer_blocks": num_transformer_blocks, } @@ -409,6 +419,12 @@ def convert_ldm_unet_checkpoint( else: raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] @@ -1034,6 +1050,7 @@ def download_from_original_stable_diffusion_ckpt( load_safety_checker: bool = True, pipeline_class: DiffusionPipeline = None, local_files_only=False, + vae_path=None, ) -> DiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` @@ -1092,6 +1109,7 @@ def download_from_original_stable_diffusion_ckpt( StableDiffusionPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, + StableDiffusionXLPipeline, ) if pipeline_class is None: @@ -1230,16 +1248,22 @@ def download_from_original_stable_diffusion_ckpt( import ipdb; ipdb.set_trace() # Convert the VAE model. - vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + if vae_path is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + else: + vae = AutoencoderKL.from_pretrained(vae_path) # Convert the text model. - if model_type is None: + if model_type is None and original_config.model.params.cond_stage_config is not None: model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config.model.params.network_config is not None: + if original_config.model.params.network_config.params.context_dim == 2048: + model_type = "SDXL" if model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) @@ -1368,6 +1392,23 @@ def download_from_original_stable_diffusion_ckpt( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + elif model_type == "SDXL": + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + pipe = StableDiffusionXLPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + # safety_checker=None, + # feature_extractor=None, + # requires_safety_checker=False, + ) else: text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 189d362663..ef759adb8d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -108,7 +108,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): def __init__( self, - # vae: AutoencoderKL, + vae: AutoencoderKL, text_encoder: CLIPTextModel, text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, @@ -138,7 +138,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): # ) self.register_modules( - # vae=vae, + vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, @@ -148,9 +148,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline): # safety_checker=safety_checker, # feature_extractor=feature_extractor, ) - # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 8 - # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_vae_slicing(self): From 51ab97a2f7e273f4f142416e58c12c90600a7d28 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 23 Jun 2023 20:48:50 +0000 Subject: [PATCH 05/14] Fix more --- .../pipelines/stable_diffusion/convert_from_ckpt.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index a16a32a8b8..d761793ac6 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -273,19 +273,18 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa head_dim = [5 * c for c in list(unet_params.channel_mult)] class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None projection_class_embeddings_input_dim = None if "num_classes" in unet_params: if unet_params.num_classes == "sequential": if unet_params.context_dim == 2048: # SDXL - class_embed_type = None addition_embed_type = "text_time" addition_time_embed_dim = 256 else: class_embed_type = "projection" - addition_embed_type = None - addition_time_embed_dim = None assert "adm_in_channels" in unet_params projection_class_embeddings_input_dim = unet_params.adm_in_channels else: @@ -1244,8 +1243,6 @@ def download_from_original_stable_diffusion_ckpt( checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema ) unet.load_state_dict(converted_unet_checkpoint) - # Works! - import ipdb; ipdb.set_trace() # Convert the VAE model. if vae_path is None: @@ -1396,7 +1393,7 @@ def download_from_original_stable_diffusion_ckpt( tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280) pipe = StableDiffusionXLPipeline( vae=vae, text_encoder=text_encoder, From dd48802fa58a0d0da88d8dd18202f3ad32c563c3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 23 Jun 2023 21:16:04 +0000 Subject: [PATCH 06/14] Fix more --- src/diffusers/models/unet_2d_condition.py | 1 + .../pipeline_stable_diffusion_xl.py | 26 ++++++++++--------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 5f53b55286..192f9663f2 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -820,6 +820,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) emb = emb + aug_emb diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index ef759adb8d..0a46aafae1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -15,7 +15,7 @@ import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union -from pytorch_lightning import seed_everything +# from pytorch_lightning import seed_everything import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection @@ -117,7 +117,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline): scheduler: KarrasDiffusionSchedulers, # safety_checker: StableDiffusionSafetyChecker, # feature_extractor: CLIPImageProcessor, - # requires_safety_checker: bool = True, ): super().__init__() @@ -151,7 +150,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline): self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - # self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_vae_slicing(self): r""" @@ -245,7 +243,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - return "cpu" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): @@ -635,12 +632,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline): height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor - seed_everything(0) + # seed_everything(0) # 1. Check inputs. Raise error if not correct - # self.check_inputs( - # prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - # ) + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -650,8 +647,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): else: batch_size = prompt_embeds.shape[0] - # device = self._execution_device - device = "cpu" + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -697,7 +693,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline): num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) - add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long) + + # TODO - find better explanations where they come from + add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long, device=add_text_embeds.device) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -736,8 +734,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline): callback(i, t, latents) if not output_type == "latent": + # CHECK there is problem here (PVP) + # self.vae = self.vae.to(dtype=torch.float32) + # latents = latents.float() image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + has_nsfw_concept = None else: image = latents has_nsfw_concept = None From 7b767803f9b3b86b2f12375475e0611bddc436f9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 25 Jun 2023 18:40:18 +0000 Subject: [PATCH 07/14] Correct more --- .../stable_diffusion/convert_from_ckpt.py | 48 ++++++++++++------- .../pipeline_stable_diffusion_xl.py | 32 ++++++++----- 2 files changed, 51 insertions(+), 29 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index d761793ac6..147e1af1d5 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -765,9 +765,12 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False): text_model_dict = {} + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + "."):]] = checkpoint[key] text_model.load_state_dict(text_model_dict) @@ -775,10 +778,11 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False): textenc_conversion_lst = [ - ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), - ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), - ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight") ] textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} @@ -865,27 +869,34 @@ def convert_paint_by_example_checkpoint(checkpoint): return model -def convert_open_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") +def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + text_model = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280) keys = list(checkpoint.keys()) text_model_dict = {} - if "cond_stage_model.model.text_projection" in checkpoint: - d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) else: d_model = 1024 text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") for key in keys: - if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer - continue - if key in textenc_conversion_map: - text_model_dict[textenc_conversion_map[key]] = checkpoint[key] - if key.startswith("cond_stage_model.model.transformer."): - new_key = key[len("cond_stage_model.model.transformer.") :] + # if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer + # continue + if key[len(prefix):] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix):]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] if new_key.endswith(".in_proj_weight"): new_key = new_key[: -len(".in_proj_weight")] new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) @@ -1391,9 +1402,10 @@ def download_from_original_stable_diffusion_ckpt( ) elif model_type == "SDXL": tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!") text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280) + text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.") pipe = StableDiffusionXLPipeline( vae=vae, text_encoder=text_encoder, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 0a46aafae1..9a75321b70 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -15,7 +15,7 @@ import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union -# from pytorch_lightning import seed_everything +from pytorch_lightning import seed_everything import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection @@ -305,7 +305,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] - text_encoders = [self.text_encoder, self.text_encoder_2] + text_encoders = [self.text_encoder.to(device), self.text_encoder_2.to(device)] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary @@ -334,7 +334,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) - + prompt_embeds = text_encoder( text_input_ids.to(device), output_hidden_states=True, @@ -523,7 +523,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline): ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + seed_everything(0) + # latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device="cpu", dtype=torch.float32) + latents = latents.to(dtype=dtype, device=device) else: latents = latents.to(device) @@ -632,8 +635,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline): height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor - # seed_everything(0) - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds @@ -660,7 +661,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline): ) prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt( prompt, - device, + "cpu", + # device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, @@ -671,6 +673,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps # 5. Prepare latent variables @@ -694,14 +697,20 @@ class StableDiffusionXLPipeline(DiffusionPipeline): prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + # TODO - find better explanations where they come from - add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long, device=add_text_embeds.device) + # original_size_as_tuple x crops_coords_top_left x target_size_as_tuple + add_time_ids = torch.tensor(2 * [[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long, device=add_text_embeds.device) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = latent_model_input * 0.07601528 # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} @@ -714,11 +723,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline): return_dict=False, )[0] # TODO(Patrick) - forward path matches + import ipdb; ipdb.set_trace() # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # if do_classifier_free_guidance: + # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf From e0a0e363764a7a2ebd414acb04e65463a17d4e0b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 25 Jun 2023 18:42:20 +0000 Subject: [PATCH 08/14] correct text encoder --- .../schedulers/scheduling_euler_discrete.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 7237128cbf..48a72cfaec 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -107,6 +107,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -123,6 +126,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -191,7 +195,16 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace" and "leading" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) From 277bc9d6235c7b3c9de7be8aaf8a9fec4c1ed646 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 25 Jun 2023 21:08:31 +0000 Subject: [PATCH 09/14] Finish all --- .../pipeline_stable_diffusion_xl.py | 29 ++++--------------- .../schedulers/scheduling_euler_discrete.py | 13 ++++++--- 2 files changed, 15 insertions(+), 27 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 9a75321b70..85a8959f69 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -437,19 +437,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline): ) return image, has_nsfw_concept - def decode_latents(self, latents): - warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead", - FutureWarning, - ) - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -542,7 +529,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, - guidance_scale: float = 7.5, + guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, @@ -710,7 +697,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = latent_model_input * 0.07601528 # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} @@ -722,13 +708,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline): added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] - # TODO(Patrick) - forward path matches - import ipdb; ipdb.set_trace() # perform guidance - # if do_classifier_free_guidance: - # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf @@ -745,9 +729,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline): if not output_type == "latent": # CHECK there is problem here (PVP) - # self.vae = self.vae.to(dtype=torch.float32) - # latents = latents.float() - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + with torch.autocast("cuda", enabled=False): + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) has_nsfw_concept = None else: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 48a72cfaec..5576fd5078 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -150,9 +150,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() @@ -160,6 +157,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -178,7 +183,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] - sample = sample / ((sigma**2 + 1) ** 0.5) + sample = sample / ((sigma **2 + 1) ** 0.5) self.is_scale_input_called = True return sample From 62a151d8f46cf9cb66f7932b9d9fb31bef70d90a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 25 Jun 2023 21:26:31 +0000 Subject: [PATCH 10/14] proof that in works in run local xl --- run_local_xl.py | 56 +++++++++++++++++++ .../pipeline_stable_diffusion_upscale.py | 3 +- .../pipeline_stable_diffusion_xl.py | 26 +++++++-- 3 files changed, 80 insertions(+), 5 deletions(-) create mode 100755 run_local_xl.py diff --git a/run_local_xl.py b/run_local_xl.py new file mode 100755 index 0000000000..db41d2cdf9 --- /dev/null +++ b/run_local_xl.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +from diffusers import DiffusionPipeline, EulerDiscreteScheduler, StableDiffusionPipeline, KDPM2DiscreteScheduler, StableDiffusionImg2ImgPipeline, HeunDiscreteScheduler, KDPM2AncestralDiscreteScheduler, DDIMScheduler +import time +import os +from huggingface_hub import HfApi +# from compel import Compel +import torch +import sys +from pathlib import Path +import requests +from PIL import Image +from io import BytesIO + +path = sys.argv[1] + +api = HfApi() +start_time = time.time() +pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16) +pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) +# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) +# pipe = StableDiffusionImg2ImgPipeline.from_pretrained(path, torch_dtype=torch.float16, safety_checker=None + +# compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder) + + +pipe = pipe.to("cuda") + +prompt = "An astronaut riding a green horse on Mars" + +# rompts = ["a cat playing with a ball++ in the forest", "a cat playing with a ball++ in the forest", "a cat playing with a ball-- in the forest"] + +# prompt_embeds = torch.cat([compel.build_conditioning_tensor(prompt) for prompt in prompts]) + +# generator = [torch.Generator(device="cuda").manual_seed(0) for _ in range(prompt_embeds.shape[0])] +# +# url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +# +# response = requests.get(url) +# image = Image.open(BytesIO(response.content)).convert("RGB") +# image.thumbnail((768, 768)) +# + +# pipe.unet.set_default_attn_processor() +image = pipe(prompt=prompt).images[0] + +file_name = f"aaa" +path = os.path.join(Path.home(), "images", f"{file_name}.png") +image.save(path) + +api.upload_file( + path_or_fileobj=path, + path_in_repo=path.split("/")[-1], + repo_id="patrickvonplaten/images", + repo_type="dataset", +) +print(f"https://huggingface.co/datasets/patrickvonplaten/images/blob/main/{file_name}.png") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 0fda05ea5e..06b6628bd3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor +from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -747,6 +747,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi AttnProcessor2_0, XFormersAttnProcessor, LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, ] # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 85a8959f69..14e674ca86 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -25,6 +25,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers +from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0 from ...utils import ( deprecate, is_accelerate_available, @@ -648,8 +649,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): ) prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt( prompt, - "cpu", - # device, + device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, @@ -727,10 +727,28 @@ class StableDiffusionXLPipeline(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [ + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ] + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if not use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + if not output_type == "latent": # CHECK there is problem here (PVP) - with torch.autocast("cuda", enabled=False): - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) has_nsfw_concept = None else: From ea4cf2592865d3d5ea62f8ed5bb5ec04a54abfcf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 25 Jun 2023 21:39:15 +0000 Subject: [PATCH 11/14] clean up --- run_local_xl.py | 56 ------------------- .../pipeline_stable_diffusion_xl.py | 32 +++++++---- 2 files changed, 20 insertions(+), 68 deletions(-) delete mode 100755 run_local_xl.py diff --git a/run_local_xl.py b/run_local_xl.py deleted file mode 100755 index db41d2cdf9..0000000000 --- a/run_local_xl.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python3 -from diffusers import DiffusionPipeline, EulerDiscreteScheduler, StableDiffusionPipeline, KDPM2DiscreteScheduler, StableDiffusionImg2ImgPipeline, HeunDiscreteScheduler, KDPM2AncestralDiscreteScheduler, DDIMScheduler -import time -import os -from huggingface_hub import HfApi -# from compel import Compel -import torch -import sys -from pathlib import Path -import requests -from PIL import Image -from io import BytesIO - -path = sys.argv[1] - -api = HfApi() -start_time = time.time() -pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16) -pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) -# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) -# pipe = StableDiffusionImg2ImgPipeline.from_pretrained(path, torch_dtype=torch.float16, safety_checker=None - -# compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder) - - -pipe = pipe.to("cuda") - -prompt = "An astronaut riding a green horse on Mars" - -# rompts = ["a cat playing with a ball++ in the forest", "a cat playing with a ball++ in the forest", "a cat playing with a ball-- in the forest"] - -# prompt_embeds = torch.cat([compel.build_conditioning_tensor(prompt) for prompt in prompts]) - -# generator = [torch.Generator(device="cuda").manual_seed(0) for _ in range(prompt_embeds.shape[0])] -# -# url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" -# -# response = requests.get(url) -# image = Image.open(BytesIO(response.content)).convert("RGB") -# image.thumbnail((768, 768)) -# - -# pipe.unet.set_default_attn_processor() -image = pipe(prompt=prompt).images[0] - -file_name = f"aaa" -path = os.path.join(Path.home(), "images", f"{file_name}.png") -image.save(path) - -api.upload_file( - path_or_fileobj=path, - path_in_repo=path.split("/")[-1], - repo_id="patrickvonplaten/images", - repo_type="dataset", -) -print(f"https://huggingface.co/datasets/patrickvonplaten/images/blob/main/{file_name}.png") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 14e674ca86..0610936b0a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -14,8 +14,7 @@ import inspect import warnings -from typing import Any, Callable, Dict, List, Optional, Union -from pytorch_lightning import seed_everything +from typing import Any, Callable, Dict, List, Optional, Union, Tuple import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection @@ -511,10 +510,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): ) if latents is None: - seed_everything(0) - # latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = randn_tensor(shape, generator=generator, device="cpu", dtype=torch.float32) - latents = latents.to(dtype=dtype, device=device) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) @@ -544,6 +540,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline): callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = (1024, 1024), + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = (1024, 1024), ): r""" Function invoked when calling the pipeline for generation. @@ -609,6 +608,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline): Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO Examples: @@ -681,15 +686,18 @@ class StableDiffusionXLPipeline(DiffusionPipeline): # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + add_text_embeds = pooled_prompt_embeds + add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.long) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - - # TODO - find better explanations where they come from - # original_size_as_tuple x crops_coords_top_left x target_size_as_tuple - add_time_ids = torch.tensor(2 * [[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long, device=add_text_embeds.device) + add_time_ids = add_time_ids.to(device) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): From 48d203eeea8568365b224581d0b197332cbcd4d1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 25 Jun 2023 23:36:28 +0000 Subject: [PATCH 12/14] Get refiner to work --- .../stable_diffusion/convert_from_ckpt.py | 31 +++++++++++++------ .../pipeline_stable_diffusion_xl.py | 6 ++-- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 147e1af1d5..e1a83b93db 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -270,16 +270,21 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: - head_dim = [5 * c for c in list(unet_params.channel_mult)] + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] class_embed_type = None addition_embed_type = None addition_time_embed_dim = None projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] if "num_classes" in unet_params: if unet_params.num_classes == "sequential": - if unet_params.context_dim == 2048: + if context_dim in [2048, 1280]: # SDXL addition_embed_type = "text_time" addition_time_embed_dim = 256 @@ -296,7 +301,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), "layers_per_block": unet_params.num_res_blocks, - "cross_attention_dim": unet_params.context_dim, + "cross_attention_dim": context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, "class_embed_type": class_embed_type, @@ -1272,6 +1277,8 @@ def download_from_original_stable_diffusion_ckpt( elif model_type is None and original_config.model.params.network_config is not None: if original_config.model.params.network_config.params.context_dim == 2048: model_type = "SDXL" + else: + model_type = "SDXL-Refiner" if model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) @@ -1400,12 +1407,18 @@ def download_from_original_stable_diffusion_ckpt( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - elif model_type == "SDXL": - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) - tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!") - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280) - text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.") + elif model_type in ["SDXL", "SDXL-Refiner"]: + if model_type == "SDXL": + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!") + text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.") + else: + tokenizer = None + text_encoder = None + tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!") + text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.") + pipe = StableDiffusionXLPipeline( vae=vae, text_encoder=text_encoder, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 0610936b0a..c99f054ef3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -104,7 +104,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "tokenizer", "text_encoder"] def __init__( self, @@ -304,8 +304,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline): batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] - text_encoders = [self.text_encoder.to(device), self.text_encoder_2.to(device)] + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary From 42168261fec78ad397167bc3b354276826000d25 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 26 Jun 2023 13:23:20 +0000 Subject: [PATCH 13/14] Add red castle --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 2 +- .../stable_diffusion/convert_from_ckpt.py | 38 +- .../pipelines/stable_diffusion_xl/__init__.py | 1 + .../pipeline_stable_diffusion_xl.py | 4 +- .../pipeline_stable_diffusion_xl_img2img.py | 832 ++++++++++++++++++ 6 files changed, 863 insertions(+), 15 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0d588a85f5..5a59250158 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -161,6 +161,7 @@ else: StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, StableDiffusionXLPipeline, + StableDiffusionXLImg2ImgPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, TextToVideoSDPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 4ac734f416..ec5b6ad4c7 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -89,7 +89,7 @@ else: StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .stable_diffusion_xl import StableDiffusionXLPipeline + from .stable_diffusion_xl import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index e1a83b93db..3c87297381 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1125,6 +1125,7 @@ def download_from_original_stable_diffusion_ckpt( StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, StableDiffusionXLPipeline, + StableDiffusionXLImg2ImgPipeline, ) if pipeline_class is None: @@ -1413,24 +1414,37 @@ def download_from_original_stable_diffusion_ckpt( text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!") text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.") + + pipe = StableDiffusionXLPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + # safety_checker=None, + # feature_extractor=None, + # requires_safety_checker=False, + ) else: tokenizer = None text_encoder = None tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!") text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.") - pipe = StableDiffusionXLPipeline( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - unet=unet, - scheduler=scheduler, - # safety_checker=None, - # feature_extractor=None, - # requires_safety_checker=False, - ) + pipe = StableDiffusionXLImg2ImgPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + # safety_checker=None, + # feature_extractor=None, + # requires_safety_checker=False, + ) else: text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 807673c3f5..58a08aab70 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -28,3 +28,4 @@ class StableDiffusionXLPipelineOutput(BaseOutput): if is_transformers_available() and is_torch_available(): from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline + from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c99f054ef3..8602d489f1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -344,7 +344,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): prompt_embeds = prompt_embeds.hidden_states[-2] - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -405,7 +405,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py new file mode 100644 index 0000000000..54754a1ca3 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -0,0 +1,832 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +import numpy as np +import PIL.Image + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0 +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionXLPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor", "tokenizer", "text_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # safety_checker: StableDiffusionSafetyChecker, + # feature_extractor: CLIPImageProcessor, + ): + super().__init__() + + # if safety_checker is None and requires_safety_checker: + # logger.warning( + # f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + # " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + # " results in services or applications open to the public. Both the diffusers team and Hugging Face" + # " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + # " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + # " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + # ) + + # if safety_checker is not None and feature_extractor is None: + # raise ValueError( + # "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + # " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + # ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + # safety_checker=safety_checker, + # feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + + # make sure the VAE is in float32 mode, as it overflows in float16 + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + self.vae.to(dtype) + init_latents = init_latents.to(dtype) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + strength: float = 0.5, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = (1024, 1024), + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = (1024, 1024), + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + aesthetic_score (`float`, *optional*, defaults to 6.0): + TODO + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + TDOO + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + add_text_embeds = pooled_prompt_embeds + + if self.unet.add_embedding.linear_1.in_features == (1280 + 5 * 256): + # refiner + add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + (aesthetic_score,))], dtype=torch.long) + neg_add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + (negative_aesthetic_score,))], dtype=torch.long) + elif self.unet.add_embedding.linear_1.in_features == (1280 + 6 * 256): + # SD-XL Base + add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.long) + neg_add_time_ids = add_time_ids.clone() + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, neg_add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [ + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ] + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if not use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + if not output_type == "latent": + # CHECK there is problem here (PVP) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + has_nsfw_concept = None + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionXLPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 13107bbf653c90350e8d12506d2530d383b1f5e4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 26 Jun 2023 15:18:14 +0000 Subject: [PATCH 14/14] Fix batch size --- .../pipeline_stable_diffusion_xl.py | 10 +++++----- .../pipeline_stable_diffusion_xl_img2img.py | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 8602d489f1..b9d3b9e2a5 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -339,13 +339,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline): text_input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] - prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) - bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -398,7 +396,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline): ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] if do_classifier_free_guidance: @@ -421,6 +418,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline): prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds def run_safety_checker(self, image, device, dtype): @@ -697,7 +697,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline): prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device) + add_time_ids = add_time_ids.to(device).repeat(num_images_per_prompt, 1) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 54754a1ca3..4c6862391e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -422,6 +422,8 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline): prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -749,7 +751,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline): prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device) + add_time_ids = add_time_ids.to(device).repeat(num_images_per_prompt, 1) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: