diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index d8e274cb93..d3a61df3ba 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -132,6 +132,7 @@ except OptionalDependencyNotAvailable:
 else:
     _import_structure["guiders"].extend(
         [
+            "AdaptiveProjectedGuidance",
             "ClassifierFreeGuidance",
             "SkipLayerGuidance",
         ]
@@ -721,6 +722,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from .utils.dummy_pt_objects import *  # noqa F403
     else:
         from .guiders import (
+            AdaptiveProjectedGuidance,
             ClassifierFreeGuidance,
             SkipLayerGuidance,
         )
diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
index adef65277b..e3c6494de0 100644
--- a/src/diffusers/guiders/__init__.py
+++ b/src/diffusers/guiders/__init__.py
@@ -18,6 +18,7 @@ from ..utils import is_torch_available
 
 
 if is_torch_available():
+    from .adaptive_projected_guidance import AdaptiveProjectedGuidance
     from .classifier_free_guidance import ClassifierFreeGuidance
     from .skip_layer_guidance import SkipLayerGuidance
 
diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py
new file mode 100644
index 0000000000..ab5175745b
--- /dev/null
+++ b/src/diffusers/guiders/adaptive_projected_guidance.py
@@ -0,0 +1,177 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union, Tuple, List
+
+import torch
+
+from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
+
+
+class AdaptiveProjectedGuidance(BaseGuidance):
+    """
+    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
+
+    Args:
+        guidance_scale (`float`, defaults to `7.5`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+            deterioration of image quality.
+        adaptive_projected_guidance_momentum (`float`, defaults to `None`):
+            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
+        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
+            The rescale threshold applied to the norm of the guidance update. Updates whose norm exceeds this value
+            are scaled down to it, which helps prevent oversaturation.
+        eta (`float`, defaults to `1.0`):
+            The multiplier applied to the parallel component of the projected guidance update.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]
+
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        adaptive_projected_guidance_momentum: Optional[float] = None,
+        adaptive_projected_guidance_rescale: float = 15.0,
+        eta: float = 1.0,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+        self.eta = eta
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None
+
+    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
+        if self._step == 0:
+            if self.adaptive_projected_guidance_momentum is not None:
+                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+        return _default_prepare_inputs(denoiser, self.num_conditions, *args)
+
+    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
+        self._num_outputs_prepared += 1
+        if self._num_outputs_prepared > self.num_conditions:
+            raise ValueError(f"Expected {self.num_conditions} outputs, but `prepare_outputs` was called more times than expected.")
+        key = self._input_predictions[self._num_outputs_prepared - 1]
+        self._preds[key] = pred
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None
+
+        if not self._is_cfg_enabled():
+            pred = pred_cond
+        else:
+            pred = normalized_guidance(
+                pred_cond,
+                pred_uncond,
+                self.guidance_scale,
+                self.momentum_buffer,
+                self.eta,
+                self.adaptive_projected_guidance_rescale,
+                self.use_original_formulation,
+            )
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
+
+
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+def normalized_guidance(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer: Optional[MomentumBuffer] = None,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+    use_original_formulation: bool = False,
+):
+    diff = pred_cond - pred_uncond
+    dim = [-i for i in range(1, len(diff.shape))]
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=dim)
+    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred = pred_cond if use_original_formulation else pred_uncond
+    pred = pred + (guidance_scale - 1) * normalized_update
+    return pred
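For intuition about what `normalized_guidance` computes, here is a minimal, self-contained sketch of the projection step on toy tensors (momentum and norm clipping omitted; the shapes and values are illustrative only, not part of the change):

```python
import torch

# Toy stand-ins for the conditional/unconditional noise predictions (B, C, H, W).
pred_cond = torch.randn(1, 4, 8, 8)
pred_uncond = torch.randn(1, 4, 8, 8)
guidance_scale, eta = 7.5, 1.0

diff = pred_cond - pred_uncond
dim = [-i for i in range(1, diff.ndim)]  # reduce over all but the batch axis

# Decompose the guidance update into components parallel and orthogonal to pred_cond.
v1 = torch.nn.functional.normalize(pred_cond.double(), dim=dim)
parallel = (diff.double() * v1).sum(dim=dim, keepdim=True) * v1
orthogonal = diff.double() - parallel

# eta scales the parallel component, which the APG paper identifies as the driver of
# oversaturation; with eta=1.0 the update reduces to the plain difference diff.
update = (orthogonal + eta * parallel).type_as(diff)
pred = pred_uncond + (guidance_scale - 1) * update
```

In the full implementation, `MomentumBuffer` additionally smooths `diff` across steps before this projection, and `norm_threshold` clips its magnitude first.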
diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py
index b3508307d4..3deacdfb28 100644
--- a/src/diffusers/guiders/classifier_free_guidance.py
+++ b/src/diffusers/guiders/classifier_free_guidance.py
@@ -106,12 +106,17 @@ class ClassifierFreeGuidance(BaseGuidance):
     def _is_cfg_enabled(self) -> bool:
         if not self._enabled:
             return False
-        skip_start_step = int(self._start * self._num_inference_steps)
-        skip_stop_step = int(self._stop * self._num_inference_steps)
-        is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
         is_close = False
         if self.use_original_formulation:
             is_close = math.isclose(self.guidance_scale, 0.0)
         else:
             is_close = math.isclose(self.guidance_scale, 1.0)
+
         return is_within_range and not is_close
diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py
index 92ae7f8518..4abb93272e 100644
--- a/src/diffusers/guiders/skip_layer_guidance.py
+++ b/src/diffusers/guiders/skip_layer_guidance.py
@@ -206,21 +206,31 @@ class SkipLayerGuidance(BaseGuidance):
     def _is_cfg_enabled(self) -> bool:
         if not self._enabled:
             return False
-        skip_start_step = int(self._start * self._num_inference_steps)
-        skip_stop_step = int(self._stop * self._num_inference_steps)
-        is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
         is_close = False
         if self.use_original_formulation:
             is_close = math.isclose(self.guidance_scale, 0.0)
         else:
             is_close = math.isclose(self.guidance_scale, 1.0)
+
         return is_within_range and not is_close
 
     def _is_slg_enabled(self) -> bool:
         if not self._enabled:
             return False
-        skip_start_step = int(self._start * self._num_inference_steps)
-        skip_stop_step = int(self._stop * self._num_inference_steps)
-        is_within_range = skip_start_step < self._step < skip_stop_step
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step < self._step < skip_stop_step
+
         is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
         return is_within_range and not is_zero
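The three guiders share the same gating pattern: `_is_cfg_enabled` / `_is_slg_enabled` now treat an unset `_num_inference_steps` as in-range instead of failing on `None`. A standalone sketch of the semantics (the function name and signature here are invented for illustration; this is not diffusers API):

```python
from typing import Optional

def guidance_active(step: int, num_inference_steps: Optional[int], start: float, stop: float) -> bool:
    # Mirrors the range check above: with an unknown total step count, assume in-range.
    if num_inference_steps is None:
        return True
    skip_start_step = int(start * num_inference_steps)
    skip_stop_step = int(stop * num_inference_steps)
    return skip_start_step <= step < skip_stop_step

assert guidance_active(5, None, start=0.2, stop=0.8)     # step count not set yet
assert not guidance_active(3, 50, start=0.2, stop=0.8)   # before step int(0.2 * 50) = 10
assert guidance_active(25, 50, start=0.2, stop=0.8)      # within [10, 40)
assert not guidance_active(45, 50, start=0.2, stop=0.8)  # at or after step 40
```

One asymmetry is kept from the old code: `_is_slg_enabled` uses a strict `<` on the lower bound, so skip-layer guidance stays disabled at the exact start step, while the classifier-free check uses `<=`.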
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
index 5df57c6c16..24d7e333c4 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
@@ -194,11 +194,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
                 PipelineImageInput,
                 required=True,
                 description="The image(s) to be used as ip adapter"
-            ),
-            InputParam(
-                "guidance_scale",
-                default=5.0,
-            ),
+            )
         ]
 
@@ -236,11 +232,10 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
 
     # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+        self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
    ):
         image_embeds = []
-        if do_classifier_free_guidance:
-            negative_image_embeds = []
+        negative_image_embeds = []
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]
@@ -259,21 +254,18 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
                 )
 
                 image_embeds.append(single_image_embeds[None, :])
-                if do_classifier_free_guidance:
-                    negative_image_embeds.append(single_negative_image_embeds[None, :])
+                negative_image_embeds.append(single_negative_image_embeds[None, :])
         else:
             for single_image_embeds in ip_adapter_image_embeds:
-                if do_classifier_free_guidance:
-                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    negative_image_embeds.append(single_negative_image_embeds)
+                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                negative_image_embeds.append(single_negative_image_embeds)
                 image_embeds.append(single_image_embeds)
 
         ip_adapter_image_embeds = []
         for i, single_image_embeds in enumerate(image_embeds):
             single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-            if do_classifier_free_guidance:
-                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+            single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
 
             single_image_embeds = single_image_embeds.to(device=device)
             ip_adapter_image_embeds.append(single_image_embeds)
@@ -323,6 +315,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
             ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
             ComponentSpec("tokenizer", CLIPTokenizer),
             ComponentSpec("tokenizer_2", CLIPTokenizer),
+            ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
         ]
 
     @property
@@ -337,7 +330,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
             InputParam("negative_prompt"),
             InputParam("negative_prompt_2"),
             InputParam("cross_attention_kwargs"),
-            InputParam("guidance_scale",default=5.0),
             InputParam("clip_skip"),
         ]
 
@@ -601,10 +593,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
         data = self.get_block_state(state)
         self.check_inputs(pipeline, data)
 
-        data.do_classifier_free_guidance = data.guidance_scale > 1.0
+        data.do_classifier_free_guidance = pipeline.guider.num_conditions > 1
         data.device = pipeline._execution_device
 
-        # Encode input prompt
         data.text_encoder_lora_scale = (
             data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None
         )
@@ -1750,7 +1741,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
             InputParam("crops_coords_top_left", default=(0, 0)),
             InputParam("negative_crops_coords_top_left", default=(0, 0)),
             InputParam("num_images_per_prompt", default=1),
-            InputParam("guidance_scale", required=True),
             InputParam("aesthetic_score", default=6.0),
             InputParam("negative_aesthetic_score", default=2.0),
         ]
@@ -1897,7 +1887,8 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
             and pipeline.unet is not None
             and pipeline.unet.config.time_cond_proj_dim is not None
         ):
-            data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
+            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
+            data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
             data.timestep_cond = self.get_guidance_scale_embedding(
                 data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim
             ).to(device=data.device, dtype=data.latents.dtype)
@@ -1925,7 +1916,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
             InputParam("crops_coords_top_left", default=(0, 0)),
             InputParam("negative_crops_coords_top_left", default=(0, 0)),
             InputParam("num_images_per_prompt", default=1),
-            InputParam("guidance_scale", default=5.0),
         ]
 
     @property
@@ -2051,7 +2041,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
             and pipeline.unet is not None
             and pipeline.unet.config.time_cond_proj_dim is not None
         ):
-            data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
+            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
+            data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
             data.timestep_cond = self.get_guidance_scale_embedding(
                 data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim
             ).to(device=data.device, dtype=data.latents.dtype)
@@ -2234,6 +2225,10 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
         data.num_channels_unet = pipeline.unet.config.in_channels
         data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
+        if data.disable_guidance:
+            pipeline.guider.force_disable()
+        else:
+            pipeline.guider.force_enable()
 
         # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
 
@@ -2354,7 +2349,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             InputParam("cross_attention_kwargs"),
             InputParam("generator"),
             InputParam("eta", default=0.0),
-            InputParam("guider_kwargs"),
         ]
 
     @property
@@ -2627,9 +2621,8 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
         if data.disable_guidance:
             pipeline.guider.force_disable()
-
-        # (3) Prepare conditional inputs for controlnet using the guider
-        data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
+        else:
+            pipeline.guider.force_enable()
 
         # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -2764,7 +2757,6 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
             ComponentSpec("controlnet", ControlNetUnionModel),
             ComponentSpec("scheduler", KarrasDiffusionSchedulers),
             ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
-            ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()),
             ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
         ]
 
@@ -2781,12 +2773,9 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
             InputParam("controlnet_conditioning_scale", default=1.0),
             InputParam("guess_mode", default=False),
             InputParam("num_images_per_prompt", default=1),
-            InputParam("guidance_scale", default=5.0),
-            InputParam("guidance_rescale", default=0.0),
             InputParam("cross_attention_kwargs"),
             InputParam("generator"),
             InputParam("eta", default=0.0),
-            InputParam("guider_kwargs")
         ]
 
     @property
@@ -3029,7 +3018,6 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                     crops_coords=data.crops_coords,
                 )
                 data.height, data.width = data.control_image[idx].shape[-2:]
-
 
         # (1.6)
         # controlnet_keep
@@ -3043,80 +3031,48 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
         # (2) Prepare conditional inputs for unet using the guider
         # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
         data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
-        data.guider_kwargs = data.guider_kwargs or {}
-        data.guider_kwargs = {
-            **data.guider_kwargs,
-            "disable_guidance": data.disable_guidance,
-            "guidance_scale": data.guidance_scale,
-            "guidance_rescale": data.guidance_rescale,
-            "batch_size": data.batch_size * data.num_images_per_prompt,
-        }
-        pipeline.guider.set_guider(pipeline, data.guider_kwargs)
-        data.prompt_embeds = pipeline.guider.prepare_input(
-            data.prompt_embeds,
-            data.negative_prompt_embeds,
-        )
-        data.add_time_ids = pipeline.guider.prepare_input(
-            data.add_time_ids,
-            data.negative_add_time_ids,
-        )
-        data.pooled_prompt_embeds = pipeline.guider.prepare_input(
-            data.pooled_prompt_embeds,
-            data.negative_pooled_prompt_embeds,
-        )
-
-        if data.num_channels_unet == 9:
-            data.mask = pipeline.guider.prepare_input(data.mask, data.mask)
-            data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents)
-
-        data.added_cond_kwargs = {
-            "text_embeds": data.pooled_prompt_embeds,
-            "time_ids": data.add_time_ids,
-        }
-
-        if data.ip_adapter_embeds is not None:
-            data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
-            data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds
+        if data.disable_guidance:
+            pipeline.guider.force_disable()
+        else:
+            pipeline.guider.force_enable()
 
         # (3) Prepare conditional inputs for controlnet using the guider
-        data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
-        data.controlnet_guider_kwargs = data.guider_kwargs or {}
-        data.controlnet_guider_kwargs = {
-            **data.controlnet_guider_kwargs,
-            "disable_guidance": data.controlnet_disable_guidance,
-            "guidance_scale": data.guidance_scale,
-            "guidance_rescale": data.guidance_rescale,
-            "batch_size": data.batch_size * data.num_images_per_prompt,
-        }
-        pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs)
-        data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
-        data.controlnet_added_cond_kwargs = {
-            "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
-            "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
-        }
-        for idx, _ in enumerate(data.control_image):
-            data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx])
+        # data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
+        # data.controlnet_added_cond_kwargs = {
+        #     "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
+        #     "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
+        # }
 
-        data.control_type = (
-            data.control_type.reshape(1, -1)
-            .to(data.device, dtype=data.prompt_embeds.dtype)
-        )
+        data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype)
         repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0]
         data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0)
-        data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type)
 
         # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
         data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)
 
-
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                # prepare latents for unet using the guider
-                data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
 
-                # prepare latents for controlnet using the guider
-                data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
+                (
+                    latents,
+                    prompt_embeds,
+                    add_time_ids,
+                    pooled_prompt_embeds,
+                    mask,
+                    masked_image_latents,
+                    ip_adapter_embeds,
+                ) = pipeline.guider.prepare_inputs(
+                    pipeline.unet,
+                    data.latents,
+                    (data.prompt_embeds, data.negative_prompt_embeds),
+                    (data.add_time_ids, data.negative_add_time_ids),
+                    (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds),
+                    data.mask,
+                    data.masked_image_latents,
+                    (data.ip_adapter_embeds, data.negative_ip_adapter_embeds),
+                )
 
                 if isinstance(data.controlnet_keep[i], list):
                     data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])]
@@ -3126,48 +3082,70 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                     data.controlnet_cond_scale = data.controlnet_cond_scale[0]
                 data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i]
 
-                data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
-                    pipeline.scheduler.scale_model_input(data.control_model_input, t),
-                    t,
-                    encoder_hidden_states=data.controlnet_prompt_embeds,
-                    controlnet_cond=data.control_image,
-                    control_type=data.control_type,
-                    control_type_idx=data.control_mode,
-                    conditioning_scale=data.cond_scale,
-                    guess_mode=data.guess_mode,
-                    added_cond_kwargs=data.controlnet_added_cond_kwargs,
-                    return_dict=False,
-                )
+                for batch_index, (
+                    latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i
+                ) in enumerate(zip(
+                    latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
+                )):
+                    latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
+
+                    # Prepare additional conditionings
+                    data.added_cond_kwargs = {
+                        "text_embeds": pooled_prompt_embeds_i,
+                        "time_ids": add_time_ids_i,
+                    }
+                    if ip_adapter_embeds_i is not None:
+                        data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i
+
+                    # Prepare controlnet additional conditionings
+                    data.controlnet_added_cond_kwargs = {
+                        "text_embeds": pooled_prompt_embeds_i,
+                        "time_ids": add_time_ids_i,
+                    }
+
+                    if pipeline.guider.is_conditional or not data.guess_mode:
+                        data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
+                            latents_i,
+                            t,
+                            encoder_hidden_states=prompt_embeds_i,
+                            controlnet_cond=data.control_image,
+                            control_type=data.control_type,
+                            control_type_idx=data.control_mode,
+                            conditioning_scale=data.cond_scale,
+                            guess_mode=data.guess_mode,
+                            added_cond_kwargs=data.controlnet_added_cond_kwargs,
+                            return_dict=False,
+                        )
+                    elif pipeline.guider.is_unconditional and data.guess_mode:
+                        # Reuse the conditional iteration's residuals, zeroed, for the unconditional batch
+                        data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples]
+                        data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample)
 
-                # when we apply guidance for unet, but not for controlnet:
-                # add 0 to the unconditional batch
-                data.down_block_res_samples = pipeline.guider.prepare_input(
-                    data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples]
-                )
-                data.mid_block_res_sample = pipeline.guider.prepare_input(
-                    data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample)
-                )
+                    # Prepare for inpainting
+                    if data.num_channels_unet == 9:
+                        latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)
 
-                data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)
-                if data.num_channels_unet == 9:
-                    data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1)
+                    data.noise_pred = pipeline.unet(
+                        latents_i,
+                        t,
+                        encoder_hidden_states=prompt_embeds_i,
+                        timestep_cond=data.timestep_cond,
+                        cross_attention_kwargs=data.cross_attention_kwargs,
+                        added_cond_kwargs=data.added_cond_kwargs,
+                        down_block_additional_residuals=data.down_block_res_samples,
+                        mid_block_additional_residual=data.mid_block_res_sample,
+                        return_dict=False,
+                    )[0]
+                    pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)  # returns None; stores the prediction on the guider
+
+                # Perform guidance once all per-condition predictions are collected
+                outputs = pipeline.guider.outputs
+                data.noise_pred = pipeline.guider(**outputs)
 
-                data.noise_pred = pipeline.unet(
-                    data.latent_model_input,
-                    t,
-                    encoder_hidden_states=data.prompt_embeds,
-                    timestep_cond=data.timestep_cond,
-                    cross_attention_kwargs=data.cross_attention_kwargs,
-                    added_cond_kwargs=data.added_cond_kwargs,
-                    down_block_additional_residuals=data.down_block_res_samples,
-                    mid_block_additional_residual=data.mid_block_res_sample,
-                    return_dict=False,
-                )[0]
-                # perform guidance
-                data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents)
-                # compute the previous noisy sample x_t -> x_t-1
+                # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
                 data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
@@ -3180,14 +3160,10 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                     data.init_latents_proper = pipeline.scheduler.add_noise(
                        data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep])
                     )
-                    data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents
 
                 if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                     progress_bar.update()
-
-        pipeline.guider.reset_guider(pipeline)
-        pipeline.controlnet_guider.reset_guider(pipeline)
 
         self.add_block_state(state, data)
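To make the new call pattern concrete, here is a hypothetical, self-contained driver loop following the same guider lifecycle the denoise steps use above (`set_state`, then `prepare_inputs`, one denoiser call per condition, `prepare_outputs`, and finally invoking the guider). The toy denoiser, shapes, and timestep schedule are invented for illustration; only the guider calls mirror the diff, and the real modular pipeline wiring may differ:

```python
import torch
from diffusers import AdaptiveProjectedGuidance

class ToyDenoiser(torch.nn.Module):
    # Stand-in for pipeline.unet; returns a fake noise prediction.
    def forward(self, latents, t, encoder_hidden_states):
        return latents * 0.9 + encoder_hidden_states.mean() * 0.1

denoiser = ToyDenoiser()
latents = torch.randn(1, 4, 8, 8)
prompt_embeds = torch.randn(1, 77, 32)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
timesteps = list(range(9, -1, -1))

guider = AdaptiveProjectedGuidance(guidance_scale=7.5)

for i, t in enumerate(timesteps):
    guider.set_state(step=i, num_inference_steps=len(timesteps), timestep=t)

    # One batch per condition: conditional first, then unconditional.
    latents_b, prompt_embeds_b = guider.prepare_inputs(
        denoiser, latents, (prompt_embeds, negative_prompt_embeds)
    )
    for latents_i, prompt_embeds_i in zip(latents_b, prompt_embeds_b):
        noise_pred = denoiser(latents_i, t, encoder_hidden_states=prompt_embeds_i)
        guider.prepare_outputs(denoiser, noise_pred)  # stores pred_cond / pred_uncond

    noise_pred = guider(**guider.outputs)  # APG-combined prediction
    # noise_pred would now go to scheduler.step(...) as in the denoise steps above.
```

Because the guider owns the conditional/unconditional batching, swapping in `ClassifierFreeGuidance` or `SkipLayerGuidance` requires no change to this loop; that is the design motivation for removing the per-block `guidance_scale` inputs and the separate `controlnet_guider` component.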