1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

support controlnet union

This commit is contained in:
Aryan
2025-04-14 15:39:16 +02:00
parent 9da8a9d1d5
commit b81bd78bf9
6 changed files with 311 additions and 143 deletions

View File

@@ -132,6 +132,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"ClassifierFreeGuidance",
"SkipLayerGuidance",
]
@@ -721,6 +722,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
ClassifierFreeGuidance,
SkipLayerGuidance,
)

View File

@@ -18,6 +18,7 @@ from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .skip_layer_guidance import SkipLayerGuidance

View File

@@ -0,0 +1,174 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Union, Tuple, List
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
class AdaptiveProjectedGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
return _default_prepare_inputs(denoiser, self.num_conditions, *args)
def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
self._preds[key] = pred
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def is_conditional(self) -> bool:
return self._num_outputs_prepared == 0
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + (guidance_scale - 1) * normalized_update
return pred

View File

@@ -106,12 +106,17 @@ class ClassifierFreeGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close

View File

@@ -206,21 +206,31 @@ class SkipLayerGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero

View File

@@ -194,11 +194,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
PipelineImageInput,
required=True,
description="The image(s) to be used as ip adapter"
),
InputParam(
"guidance_scale",
default=5.0,
),
)
]
@@ -236,11 +232,10 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -259,21 +254,18 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
negative_image_embeds.append(single_negative_image_embeds[None, :])
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
@@ -323,6 +315,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("tokenizer_2", CLIPTokenizer),
ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
]
@property
@@ -337,7 +330,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
InputParam("negative_prompt"),
InputParam("negative_prompt_2"),
InputParam("cross_attention_kwargs"),
InputParam("guidance_scale",default=5.0),
InputParam("clip_skip"),
]
@@ -601,10 +593,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
data = self.get_block_state(state)
self.check_inputs(pipeline, data)
data.do_classifier_free_guidance = data.guidance_scale > 1.0
data.do_classifier_free_guidance = pipeline.guider.num_conditions > 1
data.device = pipeline._execution_device
# Encode input prompt
data.text_encoder_lora_scale = (
data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None
@@ -1750,7 +1741,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("crops_coords_top_left", default=(0, 0)),
InputParam("negative_crops_coords_top_left", default=(0, 0)),
InputParam("num_images_per_prompt", default=1),
InputParam("guidance_scale", required=True),
InputParam("aesthetic_score", default=6.0),
InputParam("negative_aesthetic_score", default=2.0),
]
@@ -1897,7 +1887,8 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
and pipeline.unet is not None
and pipeline.unet.config.time_cond_proj_dim is not None
):
data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
# TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
data.timestep_cond = self.get_guidance_scale_embedding(
data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim
).to(device=data.device, dtype=data.latents.dtype)
@@ -1925,7 +1916,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("crops_coords_top_left", default=(0, 0)),
InputParam("negative_crops_coords_top_left", default=(0, 0)),
InputParam("num_images_per_prompt", default=1),
InputParam("guidance_scale", default=5.0),
]
@property
@@ -2051,7 +2041,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
and pipeline.unet is not None
and pipeline.unet.config.time_cond_proj_dim is not None
):
data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
# TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
data.timestep_cond = self.get_guidance_scale_embedding(
data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim
).to(device=data.device, dtype=data.latents.dtype)
@@ -2234,6 +2225,10 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
data.num_channels_unet = pipeline.unet.config.in_channels
data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
if data.disable_guidance:
pipeline.guider.force_disable()
else:
pipeline.guider.force_enable()
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -2354,7 +2349,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
InputParam("cross_attention_kwargs"),
InputParam("generator"),
InputParam("eta", default=0.0),
InputParam("guider_kwargs"),
]
@property
@@ -2627,9 +2621,8 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
if data.disable_guidance:
pipeline.guider.force_disable()
# (3) Prepare conditional inputs for controlnet using the guider
data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
else:
pipeline.guider.force_enable()
# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -2764,7 +2757,6 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
ComponentSpec("controlnet", ControlNetUnionModel),
ComponentSpec("scheduler", KarrasDiffusionSchedulers),
ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()),
ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
]
@@ -2781,12 +2773,9 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1),
InputParam("guidance_scale", default=5.0),
InputParam("guidance_rescale", default=0.0),
InputParam("cross_attention_kwargs"),
InputParam("generator"),
InputParam("eta", default=0.0),
InputParam("guider_kwargs")
]
@property
@@ -3029,7 +3018,6 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
crops_coords=data.crops_coords,
)
data.height, data.width = data.control_image[idx].shape[-2:]
# (1.6)
# controlnet_keep
@@ -3043,80 +3031,48 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
# (2) Prepare conditional inputs for unet using the guider
# adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
data.guider_kwargs = data.guider_kwargs or {}
data.guider_kwargs = {
**data.guider_kwargs,
"disable_guidance": data.disable_guidance,
"guidance_scale": data.guidance_scale,
"guidance_rescale": data.guidance_rescale,
"batch_size": data.batch_size * data.num_images_per_prompt,
}
pipeline.guider.set_guider(pipeline, data.guider_kwargs)
data.prompt_embeds = pipeline.guider.prepare_input(
data.prompt_embeds,
data.negative_prompt_embeds,
)
data.add_time_ids = pipeline.guider.prepare_input(
data.add_time_ids,
data.negative_add_time_ids,
)
data.pooled_prompt_embeds = pipeline.guider.prepare_input(
data.pooled_prompt_embeds,
data.negative_pooled_prompt_embeds,
)
if data.num_channels_unet == 9:
data.mask = pipeline.guider.prepare_input(data.mask, data.mask)
data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents)
data.added_cond_kwargs = {
"text_embeds": data.pooled_prompt_embeds,
"time_ids": data.add_time_ids,
}
if data.ip_adapter_embeds is not None:
data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds
if data.disable_guidance:
pipeline.guider.force_disable()
else:
pipeline.guider.force_enable()
# (3) Prepare conditional inputs for controlnet using the guider
data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
data.controlnet_guider_kwargs = data.guider_kwargs or {}
data.controlnet_guider_kwargs = {
**data.controlnet_guider_kwargs,
"disable_guidance": data.controlnet_disable_guidance,
"guidance_scale": data.guidance_scale,
"guidance_rescale": data.guidance_rescale,
"batch_size": data.batch_size * data.num_images_per_prompt,
}
pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs)
data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
data.controlnet_added_cond_kwargs = {
"text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
"time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
}
for idx, _ in enumerate(data.control_image):
data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx])
# data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
# data.controlnet_added_cond_kwargs = {
# "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
# "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
# }
data.control_type = (
data.control_type.reshape(1, -1)
.to(data.device, dtype=data.prompt_embeds.dtype)
)
data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype)
repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0]
data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0)
data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type)
# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
for i, t in enumerate(data.timesteps):
# prepare latents for unet using the guider
data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
# prepare latents for controlnet using the guider
data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
(
latents,
prompt_embeds,
add_time_ids,
pooled_prompt_embeds,
mask,
masked_image_latents,
ip_adapter_embeds,
) = pipeline.guider.prepare_inputs(
pipeline.unet,
data.latents,
(data.prompt_embeds, data.negative_prompt_embeds),
(data.add_time_ids, data.negative_add_time_ids),
(data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds),
data.mask,
data.masked_image_latents,
(data.ip_adapter_embeds, data.negative_ip_adapter_embeds),
)
if isinstance(data.controlnet_keep[i], list):
data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])]
@@ -3126,48 +3082,72 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
data.controlnet_cond_scale = data.controlnet_cond_scale[0]
data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i]
data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
pipeline.scheduler.scale_model_input(data.control_model_input, t),
t,
encoder_hidden_states=data.controlnet_prompt_embeds,
controlnet_cond=data.control_image,
control_type=data.control_type,
control_type_idx=data.control_mode,
conditioning_scale=data.cond_scale,
guess_mode=data.guess_mode,
added_cond_kwargs=data.controlnet_added_cond_kwargs,
return_dict=False,
)
for batch_index, (
latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i
) in enumerate(zip(
latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
)):
latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
# Prepare for inpainting
if data.num_channels_unet == 9:
latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)
# Prepare additional conditionings
data.added_cond_kwargs = {
"text_embeds": pooled_prompt_embeds_i,
"time_ids": add_time_ids_i,
}
if ip_adapter_embeds_i is not None:
data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i
# Prepare controlnet additional conditionings
data.controlnet_added_cond_kwargs = {
"text_embeds": pooled_prompt_embeds_i,
"time_ids": add_time_ids_i,
}
if pipeline.guider.is_conditional or not data.guess_mode:
data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
latents_i,
t,
encoder_hidden_states=prompt_embeds_i,
controlnet_cond=data.control_image,
control_type=data.control_type,
control_type_idx=data.control_mode,
conditioning_scale=data.cond_scale,
guess_mode=data.guess_mode,
added_cond_kwargs=data.controlnet_added_cond_kwargs,
return_dict=False,
)
elif pipeline.guider.is_unconditional and data.guess_mode:
data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples]
data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample)
# when we apply guidance for unet, but not for controlnet:
# add 0 to the unconditional batch
data.down_block_res_samples = pipeline.guider.prepare_input(
data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples]
)
data.mid_block_res_sample = pipeline.guider.prepare_input(
data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample)
)
if data.num_channels_unet == 9:
latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)
data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)
if data.num_channels_unet == 9:
data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1)
data.noise_pred = pipeline.unet(
latents_i,
t,
encoder_hidden_states=prompt_embeds_i,
timestep_cond=data.timestep_cond,
cross_attention_kwargs=data.cross_attention_kwargs,
added_cond_kwargs=data.added_cond_kwargs,
down_block_additional_residuals=data.down_block_res_samples,
mid_block_additional_residual=data.mid_block_res_sample,
return_dict=False,
)[0]
data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
# Perform guidance
outputs = pipeline.guider.outputs
data.noise_pred = pipeline.guider(**outputs)
data.noise_pred = pipeline.unet(
data.latent_model_input,
t,
encoder_hidden_states=data.prompt_embeds,
timestep_cond=data.timestep_cond,
cross_attention_kwargs=data.cross_attention_kwargs,
added_cond_kwargs=data.added_cond_kwargs,
down_block_additional_residuals=data.down_block_res_samples,
mid_block_additional_residual=data.mid_block_res_sample,
return_dict=False,
)[0]
# perform guidance
data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents)
# compute the previous noisy sample x_t -> x_t-1
# Perform scheduler step using the predicted output
data.latents_dtype = data.latents.dtype
data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
if data.latents.dtype != data.latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
@@ -3180,14 +3160,10 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
data.init_latents_proper = pipeline.scheduler.add_noise(
data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep])
)
data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents
if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
progress_bar.update()
pipeline.guider.reset_guider(pipeline)
pipeline.controlnet_guider.reset_guider(pipeline)
self.add_block_state(state, data)