From 8ee23f22d145d7699b15620f1900d1bb17af2bfa Mon Sep 17 00:00:00 2001 From: Nan Liu <45443761+nanlliu@users.noreply.github.com> Date: Tue, 13 Dec 2022 18:30:18 -0600 Subject: [PATCH] update composable diffusion for an updated diffuser library --- examples/community/README.md | 52 +- .../community/composable_stable_diffusion.py | 652 ++++++++++++------ 2 files changed, 475 insertions(+), 229 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index d95aff7d8a..ac4da49454 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -355,43 +355,45 @@ out = pipe( import torch as th import numpy as np import torchvision.utils as tvu + from diffusers import DiffusionPipeline +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--prompt", type=str, default="mystical trees | A magical pond | dark", + help="use '|' as the delimiter to compose separate sentences.") +parser.add_argument("--steps", type=int, default=50) +parser.add_argument("--scale", type=float, default=7.5) +parser.add_argument("--weights", type=str, default="7.5 | 7.5 | -7.5") +parser.add_argument("--seed", type=int, default=2) +parser.add_argument("--model_path", type=str, default="CompVis/stable-diffusion-v1-4") +parser.add_argument("--num_images", type=int, default=1) +args = parser.parse_args() + has_cuda = th.cuda.is_available() device = th.device('cpu' if not has_cuda else 'cuda') +prompt = args.prompt +scale = args.scale +steps = args.steps + pipe = DiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - use_auth_token=True, + args.model_path, custom_pipeline="composable_stable_diffusion", ).to(device) - -def dummy(images, **kwargs): - return images, False - -pipe.safety_checker = dummy +pipe.safety_checker = None images = [] -generator = torch.Generator("cuda").manual_seed(0) +generator = th.Generator("cuda").manual_seed(args.seed) +for i in range(args.num_images): + image = pipe(prompt, guidance_scale=scale, num_inference_steps=steps, + weights=args.weights, generator=generator).images[0] + images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.) +grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0) +tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png') -seed = 0 -prompt = "a forest | a camel" -weights = " 1 | 1" # Equal weight to each prompt. Can be negative - -images = [] -for i in range(4): - res = pipe( - prompt, - guidance_scale=7.5, - num_inference_steps=50, - weights=weights, - generator=generator) - image = res.images[0] - images.append(image) - -for i, img in enumerate(images): - img.save(f"./composable_diffusion/image_{i}.png") ``` ### Imagic Stable Diffusion diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py index eb207e1bdd..8bf2953967 100644 --- a/examples/community/composable_stable_diffusion.py +++ b/examples/community/composable_stable_diffusion.py @@ -1,25 +1,52 @@ -""" - modified based on diffusion library from Huggingface: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py -""" +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect -import warnings -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import torch -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import is_accelerate_available +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import deprecate, logging +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + class ComposableStableDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. @@ -35,11 +62,12 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -47,11 +75,84 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -61,56 +162,265 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + def enable_vae_slicing(self): r""" - Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) + batch_size = len(prompt) if isinstance(prompt, list) else 1 - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, weights: Optional[str] = "", - **kwargs, ): r""" Function invoked when calling the pipeline for generation. + Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -121,6 +431,11 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -137,6 +452,13 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. @@ -144,186 +466,108 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." - ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if "|" in prompt: - prompt = [x.strip() for x in prompt.split("|")] - print(f"composing {prompt}...") - - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] - - if not weights: - # specify weights for prompts (excluding the unconditional score) - print("using equal weights for all prompts...") - pos_weights = torch.tensor( - [1 / (text_embeddings.shape[0] - 1)] * (text_embeddings.shape[0] - 1), device=self.device - ).reshape(-1, 1, 1, 1) - neg_weights = torch.tensor([1.0], device=self.device).reshape(-1, 1, 1, 1) - mask = torch.tensor([False] + [True] * pos_weights.shape[0], dtype=torch.bool) - else: - # set prompt weight for each - num_prompts = len(prompt) if isinstance(prompt, list) else 1 - weights = [float(w.strip()) for w in weights.split("|")] - if len(weights) < num_prompts: - weights.append(1.0) - weights = torch.tensor(weights, device=self.device) - assert len(weights) == text_embeddings.shape[0], "weights specified are not equal to the number of prompts" - pos_weights = [] - neg_weights = [] - mask = [] # first one is unconditional score - for w in weights: - if w > 0: - pos_weights.append(w) - mask.append(True) - else: - neg_weights.append(abs(w)) - mask.append(False) - # normalize the weights - pos_weights = torch.tensor(pos_weights, device=self.device).reshape(-1, 1, 1, 1) - pos_weights = pos_weights / pos_weights.sum() - neg_weights = torch.tensor(neg_weights, device=self.device).reshape(-1, 1, 1, 1) - neg_weights = neg_weights / neg_weights.sum() - mask = torch.tensor(mask, device=self.device, dtype=torch.bool) + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - max_length = text_input.input_ids.shape[-1] - if torch.all(mask): - # no negative prompts, so we use empty string as the negative prompt - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + if '|' in prompt: + prompt = [x.strip() for x in prompt.split('|')] + print(f"composing {prompt}...") - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # update negative weights - neg_weights = torch.tensor([1.0], device=self.device) - mask = torch.tensor([False] + mask.detach().tolist(), device=self.device, dtype=torch.bool) - - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_device = "cpu" if self.device.type == "mps" else self.device - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) - if latents is None: - latents = torch.randn( - latents_shape, - generator=generator, - device=latents_device, - ) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * text_embeddings.shape[0]) if do_classifier_free_guidance else latents - ) - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - - # reduce memory by predicting each score sequentially - noise_preds = [] - # predict the noise residual - for latent_in, text_embedding_in in zip( - torch.chunk(latent_model_input, chunks=latent_model_input.shape[0], dim=0), - torch.chunk(text_embeddings, chunks=text_embeddings.shape[0], dim=0), - ): - noise_preds.append(self.unet(latent_in, t, encoder_hidden_states=text_embedding_in).sample) - noise_preds = torch.cat(noise_preds, dim=0) - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond = (noise_preds[~mask] * neg_weights).sum(dim=0, keepdims=True) - noise_pred_text = (noise_preds[mask] * pos_weights).sum(dim=0, keepdims=True) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + if not weights: + # specify weights for prompts (excluding the unconditional score) + print('using equal positive weights (conjunction) for all prompts...') + weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1) else: + # set prompt weight for each + num_prompts = len(prompt) if isinstance(prompt, list) else 1 + weights = [float(w.strip()) for w in weights.split("|")] + # guidance scale as the default + if len(weights) < num_prompts: + weights.append(guidance_scale) + else: + weights = weights[:num_prompts] + assert len(weights) == len(prompt), "weights specified are not equal to the number of prompts" + weights = torch.tensor(weights, device=self.device).reshape(-1, 1, 1, 1) + else: + weights = guidance_scale + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # copy the latents for each prompt we need to compose + if isinstance(prompt, list) and batch_size == 1: + latents = latents.repeat(len(prompt), 1, 1, 1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(dim=0, keepdims=True) + + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + # 8. Post-processing + image = self.decode_latents(latents) - # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + # 10. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file