diff --git a/examples/community/README.md b/examples/community/README.md index 5cebc4f9f0..8afe8d42e3 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -68,6 +68,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif | InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) | | UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) | | Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) | +| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. @@ -1676,6 +1677,68 @@ image = pipe(prompt, image=input_image, strength=0.75,).images[0] image.save('tensorrt_img2img_new_zealand_hills.png') ``` +### Stable Diffusion BoxDiff +BoxDiff is a training-free method for controlled generation with bounding box coordinates. It shoud work with any Stable Diffusion model. Below shows an example with `stable-diffusion-2-1-base`. +```py +import torch +from PIL import Image, ImageDraw +from copy import deepcopy + +from examples.community.pipeline_stable_diffusion_boxdiff import StableDiffusionBoxDiffPipeline + +def draw_box_with_text(img, boxes, names): + colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"] + img_new = deepcopy(img) + draw = ImageDraw.Draw(img_new) + + W, H = img.size + for bid, box in enumerate(boxes): + draw.rectangle([box[0] * W, box[1] * H, box[2] * W, box[3] * H], outline=colors[bid % len(colors)], width=4) + draw.text((box[0] * W, box[1] * H), names[bid], fill=colors[bid % len(colors)]) + return img_new + +pipe = StableDiffusionBoxDiffPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", + torch_dtype=torch.float16, +) +pipe.to("cuda") + +# example 1 +prompt = "as the aurora lights up the sky, a herd of reindeer leisurely wanders on the grassy meadow, admiring the breathtaking view, a serene lake quietly reflects the magnificent display, and in the distance, a snow-capped mountain stands majestically, fantasy, 8k, highly detailed" +phrases = [ + "aurora", + "reindeer", + "meadow", + "lake", + "mountain" +] +boxes = [[1,3,512,202], [75,344,421,495], [1,327,508,507], [2,217,507,341], [1,135,509,242]] + +# example 2 +# prompt = "A rabbit wearing sunglasses looks very proud" +# phrases = ["rabbit", "sunglasses"] +# boxes = [[67,87,366,512], [66,130,364,262]] + +boxes = [[x / 512 for x in box] for box in boxes] + +images = pipe( + prompt, + boxdiff_phrases=phrases, + boxdiff_boxes=boxes, + boxdiff_kwargs={ + "attention_res": 16, + "normalize_eot": True + }, + num_inference_steps=50, + guidance_scale=7.5, + generator=torch.manual_seed(42), + safety_checker=None +).images + +draw_box_with_text(images[0], boxes, phrases).save("output.png") +``` + + ### Stable Diffusion Reference This pipeline uses the Reference Control. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236)[sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280). diff --git a/examples/community/pipeline_stable_diffusion_boxdiff.py b/examples/community/pipeline_stable_diffusion_boxdiff.py new file mode 100644 index 0000000000..f825339441 --- /dev/null +++ b/examples/community/pipeline_stable_diffusion_boxdiff.py @@ -0,0 +1,1700 @@ +# Copyright 2024 Jingyang Zhang and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import inspect +import math +import numbers +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0 +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class GaussianSmoothing(nn.Module): + """ + Copied from official repo: https://github.com/showlab/BoxDiff/blob/master/utils/gaussian_smoothing.py + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + def __init__(self, channels, kernel_size, sigma, dim=2): + super(GaussianSmoothing, self).__init__() + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim)) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) + + +class AttendExciteCrossAttnProcessor: + def __init__(self, attnstore, place_in_unet): + super().__init__() + self.attnstore = attnstore + self.place_in_unet = place_in_unet + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=1) + query = attn.to_q(hidden_states) + + is_cross = encoder_hidden_states is not None + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + self.attnstore(attention_probs, is_cross, self.place_in_unet) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class AttentionControl(abc.ABC): + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + # @property + # def num_uncond_att_layers(self): + # return 0 + + @abc.abstractmethod + def forward(self, attn, is_cross: bool, place_in_unet: str): + raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + self.forward(attn, is_cross, place_in_unet) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def __init__(self): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + +class AttentionStore(AttentionControl): + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= 32**2: # avoid memory overhead + self.step_store[key].append(attn) + return attn + + def between_steps(self): + self.attention_store = self.step_store + if self.save_global_store: + with torch.no_grad(): + if len(self.global_store) == 0: + self.global_store = self.step_store + else: + for key in self.global_store: + for i in range(len(self.global_store[key])): + self.global_store[key][i] += self.step_store[key][i].detach() + self.step_store = self.get_empty_store() + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = self.attention_store + return average_attention + + def get_average_global_attention(self): + average_attention = { + key: [item / self.cur_step for item in self.global_store[key]] for key in self.attention_store + } + return average_attention + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_store() + self.attention_store = {} + self.global_store = {} + + def __init__(self, save_global_store=False): + """ + Initialize an empty AttentionStore + :param step_index: used to visualize only a specific step in the diffusion process + """ + super(AttentionStore, self).__init__() + self.save_global_store = save_global_store + self.step_store = self.get_empty_store() + self.attention_store = {} + self.global_store = {} + self.curr_step_index = 0 + self.num_uncond_att_layers = 0 + + +def aggregate_attention( + attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int +) -> torch.Tensor: + """Aggregates the attention across the different layers and heads at the specified resolution.""" + out = [] + attention_maps = attention_store.get_average_attention() + + # for k, v in attention_maps.items(): + # for vv in v: + # print(vv.shape) + # exit() + + num_pixels = res**2 + for location in from_where: + for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + if item.shape[1] == num_pixels: + cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] + out.append(cross_maps) + out = torch.cat(out, dim=0) + out = out.sum(0) / out.shape[0] + return out + + +def register_attention_control(model, controller): + attn_procs = {} + cross_att_count = 0 + for name in model.unet.attn_processors.keys(): + # cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim + if name.startswith("mid_block"): + # hidden_size = model.unet.config.block_out_channels[-1] + place_in_unet = "mid" + elif name.startswith("up_blocks"): + # block_id = int(name[len("up_blocks.")]) + # hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] + place_in_unet = "up" + elif name.startswith("down_blocks"): + # block_id = int(name[len("down_blocks.")]) + # hidden_size = model.unet.config.block_out_channels[block_id] + place_in_unet = "down" + else: + continue + + cross_att_count += 1 + attn_procs[name] = AttendExciteCrossAttnProcessor(attnstore=controller, place_in_unet=place_in_unet) + model.unet.set_attn_processor(attn_procs) + controller.num_att_layers = cross_att_count + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionBoxDiffPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with BoxDiff. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return text_inputs, prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + boxdiff_phrases, + boxdiff_boxes, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if boxdiff_phrases is not None or boxdiff_boxes is not None: + if not (boxdiff_phrases is not None and boxdiff_boxes is not None): + raise ValueError("Either both `boxdiff_phrases` and `boxdiff_boxes` must be passed or none of them.") + + if not isinstance(boxdiff_phrases, list) or not isinstance(boxdiff_boxes, list): + raise ValueError("`boxdiff_phrases` and `boxdiff_boxes` must be lists.") + + if len(boxdiff_phrases) != len(boxdiff_boxes): + raise ValueError( + "`boxdiff_phrases` and `boxdiff_boxes` must have the same length," + f" got: `boxdiff_phrases` {len(boxdiff_phrases)} != `boxdiff_boxes`" + f" {len(boxdiff_boxes)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections + def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + """ + self.fusing_unet = False + self.fusing_vae = False + + if unet: + self.fusing_unet = True + self.unet.fuse_qkv_projections() + self.unet.set_attn_processor(FusedAttnProcessor2_0()) + + if vae: + if not isinstance(self.vae, AutoencoderKL): + raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") + + self.fusing_vae = True + self.vae.fuse_qkv_projections() + self.vae.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if unet: + if not self.fusing_unet: + logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") + else: + self.unet.unfuse_qkv_projections() + self.fusing_unet = False + + if vae: + if not self.fusing_vae: + logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.") + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def _compute_max_attention_per_index( + self, + attention_maps: torch.Tensor, + indices_to_alter: List[int], + smooth_attentions: bool = False, + sigma: float = 0.5, + kernel_size: int = 3, + normalize_eot: bool = False, + bboxes: List[int] = None, + L: int = 1, + P: float = 0.2, + ) -> List[torch.Tensor]: + """Computes the maximum attention value for each of the tokens we wish to alter.""" + last_idx = -1 + if normalize_eot: + prompt = self.prompt + if isinstance(self.prompt, list): + prompt = self.prompt[0] + last_idx = len(self.tokenizer(prompt)["input_ids"]) - 1 + attention_for_text = attention_maps[:, :, 1:last_idx] + attention_for_text *= 100 + attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) + + # Shift indices since we removed the first token "1:last_idx" + indices_to_alter = [index - 1 for index in indices_to_alter] + + # Extract the maximum values + max_indices_list_fg = [] + max_indices_list_bg = [] + dist_x = [] + dist_y = [] + + cnt = 0 + for i in indices_to_alter: + image = attention_for_text[:, :, i] + + # TODO + # box = [max(round(b / (512 / image.shape[0])), 0) for b in bboxes[cnt]] + # x1, y1, x2, y2 = box + H, W = image.shape + x1 = min(max(round(bboxes[cnt][0] * W), 0), W) + y1 = min(max(round(bboxes[cnt][1] * H), 0), H) + x2 = min(max(round(bboxes[cnt][2] * W), 0), W) + y2 = min(max(round(bboxes[cnt][3] * H), 0), H) + box = [x1, y1, x2, y2] + cnt += 1 + + # coordinates to masks + obj_mask = torch.zeros_like(image) + ones_mask = torch.ones([y2 - y1, x2 - x1], dtype=obj_mask.dtype).to(obj_mask.device) + obj_mask[y1:y2, x1:x2] = ones_mask + bg_mask = 1 - obj_mask + + if smooth_attentions: + smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(image.device) + input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect") + image = smoothing(input).squeeze(0).squeeze(0) + + # Inner-Box constraint + k = (obj_mask.sum() * P).long() + max_indices_list_fg.append((image * obj_mask).reshape(-1).topk(k)[0].mean()) + + # Outer-Box constraint + k = (bg_mask.sum() * P).long() + max_indices_list_bg.append((image * bg_mask).reshape(-1).topk(k)[0].mean()) + + # Corner Constraint + gt_proj_x = torch.max(obj_mask, dim=0)[0] + gt_proj_y = torch.max(obj_mask, dim=1)[0] + corner_mask_x = torch.zeros_like(gt_proj_x) + corner_mask_y = torch.zeros_like(gt_proj_y) + + # create gt according to the number config.L + N = gt_proj_x.shape[0] + corner_mask_x[max(box[0] - L, 0) : min(box[0] + L + 1, N)] = 1.0 + corner_mask_x[max(box[2] - L, 0) : min(box[2] + L + 1, N)] = 1.0 + corner_mask_y[max(box[1] - L, 0) : min(box[1] + L + 1, N)] = 1.0 + corner_mask_y[max(box[3] - L, 0) : min(box[3] + L + 1, N)] = 1.0 + dist_x.append((F.l1_loss(image.max(dim=0)[0], gt_proj_x, reduction="none") * corner_mask_x).mean()) + dist_y.append((F.l1_loss(image.max(dim=1)[0], gt_proj_y, reduction="none") * corner_mask_y).mean()) + + return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y + + def _aggregate_and_get_max_attention_per_token( + self, + attention_store: AttentionStore, + indices_to_alter: List[int], + attention_res: int = 16, + smooth_attentions: bool = False, + sigma: float = 0.5, + kernel_size: int = 3, + normalize_eot: bool = False, + bboxes: List[int] = None, + L: int = 1, + P: float = 0.2, + ): + """Aggregates the attention for each token and computes the max activation value for each token to alter.""" + attention_maps = aggregate_attention( + attention_store=attention_store, + res=attention_res, + from_where=("up", "down", "mid"), + is_cross=True, + select=0, + ) + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._compute_max_attention_per_index( + attention_maps=attention_maps, + indices_to_alter=indices_to_alter, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + return max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y + + @staticmethod + def _compute_loss( + max_attention_per_index_fg: List[torch.Tensor], + max_attention_per_index_bg: List[torch.Tensor], + dist_x: List[torch.Tensor], + dist_y: List[torch.Tensor], + return_losses: bool = False, + ) -> torch.Tensor: + """Computes the attend-and-excite loss using the maximum attention value for each token.""" + losses_fg = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index_fg] + losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg] + loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y) + if return_losses: + return max(losses_fg), losses_fg + else: + return max(losses_fg), loss + + @staticmethod + def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: + """Update the latent according to the computed loss.""" + grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] + latents = latents - step_size * grad_cond + return latents + + def _perform_iterative_refinement_step( + self, + latents: torch.Tensor, + indices_to_alter: List[int], + loss_fg: torch.Tensor, + threshold: float, + text_embeddings: torch.Tensor, + text_input, + attention_store: AttentionStore, + step_size: float, + t: int, + attention_res: int = 16, + smooth_attentions: bool = True, + sigma: float = 0.5, + kernel_size: int = 3, + max_refinement_steps: int = 20, + normalize_eot: bool = False, + bboxes: List[int] = None, + L: int = 1, + P: float = 0.2, + ): + """ + Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent + code according to our loss objective until the given threshold is reached for all tokens. + """ + iteration = 0 + target_loss = max(0, 1.0 - threshold) + + while loss_fg > target_loss: + iteration += 1 + + latents = latents.clone().detach().requires_grad_(True) + # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + ( + max_attention_per_index_fg, + max_attention_per_index_bg, + dist_x, + dist_y, + ) = self._aggregate_and_get_max_attention_per_token( + attention_store=attention_store, + indices_to_alter=indices_to_alter, + attention_res=attention_res, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + + loss_fg, losses_fg = self._compute_loss( + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True + ) + + if loss_fg != 0: + latents = self._update_latent(latents, loss_fg, step_size) + + # with torch.no_grad(): + # noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample + # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample + + # try: + # low_token = np.argmax([l.item() if not isinstance(l, int) else l for l in losses_fg]) + # except Exception as e: + # print(e) # catch edge case :) + # low_token = np.argmax(losses_fg) + + # low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]]) + # print(f'\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index_fg[low_token]}') + + if iteration >= max_refinement_steps: + # print(f'\t Exceeded max number of iterations ({max_refinement_steps})! ' + # f'Finished with a max attention of {max_attention_per_index_fg[low_token]}') + break + + # Run one more time but don't compute gradients and update the latents. + # We just need to compute the new loss - the grad update will occur below + latents = latents.clone().detach().requires_grad_(True) + # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + ( + max_attention_per_index_fg, + max_attention_per_index_bg, + dist_x, + dist_y, + ) = self._aggregate_and_get_max_attention_per_token( + attention_store=attention_store, + indices_to_alter=indices_to_alter, + attention_res=attention_res, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + loss_fg, losses_fg = self._compute_loss( + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True + ) + # print(f"\t Finished with loss of: {loss_fg}") + return loss_fg, latents, max_attention_per_index_fg + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + boxdiff_phrases: List[str] = None, + boxdiff_boxes: List[List[float]] = None, # TODO + boxdiff_kwargs: Optional[Dict[str, Any]] = { + "attention_res": 16, + "P": 0.2, + "L": 1, + "max_iter_to_alter": 25, + "loss_thresholds": {0: 0.05, 10: 0.5, 20: 0.8}, + "scale_factor": 20, + "scale_range": (1.0, 0.5), + "smooth_attentions": True, + "sigma": 0.5, + "kernel_size": 3, + "refine": False, + "normalize_eot": True, + }, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + + boxdiff_attention_res (`int`, *optional*, defaults to 16): + The resolution of the attention maps used for computing the BoxDiff loss. + boxdiff_P (`float`, *optional*, defaults to 0.2): + + boxdiff_L (`int`, *optional*, defaults to 1): + The number of pixels around the corner to be selected in BoxDiff loss. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # -1. Register attention control (for BoxDiff) + attention_store = AttentionStore() + register_attention_control(self, attention_store) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + boxdiff_phrases, + boxdiff_boxes, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + self.prompt = prompt + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + text_inputs, prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 6.3 Prepare BoxDiff inputs + # a) Indices to alter + input_ids = self.tokenizer(prompt)["input_ids"] + decoded = [self.tokenizer.decode([t]) for t in input_ids] + indices_to_alter = [] + bboxes = [] + for phrase, box in zip(boxdiff_phrases, boxdiff_boxes): + # it could happen that phrase does not correspond a single token? + if phrase not in decoded: + continue + indices_to_alter.append(decoded.index(phrase)) + bboxes.append(box) + + # b) A bunch of hyperparameters + attention_res = boxdiff_kwargs.get("attention_res", 16) + smooth_attentions = boxdiff_kwargs.get("smooth_attentions", True) + sigma = boxdiff_kwargs.get("sigma", 0.5) + kernel_size = boxdiff_kwargs.get("kernel_size", 3) + L = boxdiff_kwargs.get("L", 1) + P = boxdiff_kwargs.get("P", 0.2) + thresholds = boxdiff_kwargs.get("loss_thresholds", {0: 0.05, 10: 0.5, 20: 0.8}) + max_iter_to_alter = boxdiff_kwargs.get("max_iter_to_alter", len(self.scheduler.timesteps) + 1) + scale_factor = boxdiff_kwargs.get("scale_factor", 20) + refine = boxdiff_kwargs.get("refine", False) + normalize_eot = boxdiff_kwargs.get("normalize_eot", True) + + scale_range = boxdiff_kwargs.get("scale_range", (1.0, 0.5)) + scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps)) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # BoxDiff optimization + with torch.enable_grad(): + latents = latents.clone().detach().requires_grad_(True) + + # Forward pass of denoising with text conditioning + noise_pred_text = self.unet( + latents, + t, + encoder_hidden_states=prompt_embeds[1].unsqueeze(0), + cross_attention_kwargs=cross_attention_kwargs, + ).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + ( + max_attention_per_index_fg, + max_attention_per_index_bg, + dist_x, + dist_y, + ) = self._aggregate_and_get_max_attention_per_token( + attention_store=attention_store, + indices_to_alter=indices_to_alter, + attention_res=attention_res, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + + loss_fg, loss = self._compute_loss( + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y + ) + + # Refinement from attend-and-excite (not necessary) + if refine and i in thresholds.keys() and loss_fg > 1.0 - thresholds[i]: + del noise_pred_text + torch.cuda.empty_cache() + loss_fg, latents, max_attention_per_index_fg = self._perform_iterative_refinement_step( + latents=latents, + indices_to_alter=indices_to_alter, + loss_fg=loss_fg, + threshold=thresholds[i], + text_embeddings=prompt_embeds, + text_input=text_inputs, + attention_store=attention_store, + step_size=scale_factor * np.sqrt(scale_range[i]), + t=t, + attention_res=attention_res, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + + # Perform gradient update + if i < max_iter_to_alter: + _, loss = self._compute_loss( + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y + ) + if loss != 0: + latents = self._update_latent( + latents=latents, loss=loss, step_size=scale_factor * np.sqrt(scale_range[i]) + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)