mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)

support controlnet union
@@ -132,6 +132,7 @@ except OptionalDependencyNotAvailable:
else:
    _import_structure["guiders"].extend(
        [
            "AdaptiveProjectedGuidance",
            "ClassifierFreeGuidance",
            "SkipLayerGuidance",
        ]

@@ -721,6 +722,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .utils.dummy_pt_objects import *  # noqa F403
else:
    from .guiders import (
        AdaptiveProjectedGuidance,
        ClassifierFreeGuidance,
        SkipLayerGuidance,
    )
@@ -18,6 +18,7 @@ from ..utils import is_torch_available


if is_torch_available():
    from .adaptive_projected_guidance import AdaptiveProjectedGuidance
    from .classifier_free_guidance import ClassifierFreeGuidance
    from .skip_layer_guidance import SkipLayerGuidance
src/diffusers/guiders/adaptive_projected_guidance.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs


class AdaptiveProjectedGuidance(BaseGuidance):
    """
    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation
            and deterioration of image quality.
        adaptive_projected_guidance_momentum (`float`, defaults to `None`):
            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum: Optional[float] = None,
        adaptive_projected_guidance_rescale: float = 15.0,
        eta: float = 1.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        self.momentum_buffer = None

    def prepare_inputs(
        self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]
    ) -> Tuple[List[torch.Tensor], ...]:
        if self._step == 0:
            if self.adaptive_projected_guidance_momentum is not None:
                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
        return _default_prepare_inputs(denoiser, self.num_conditions, *args)

    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
        self._num_outputs_prepared += 1
        if self._num_outputs_prepared > self.num_conditions:
            raise ValueError(f"Expected {self.num_conditions} outputs, but `prepare_outputs` was called more times.")
        key = self._input_predictions[self._num_outputs_prepared - 1]
        self._preds[key] = pred

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_cfg_enabled():
            pred = pred_cond
        else:
            pred = normalized_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.momentum_buffer,
                self.eta,
                self.adaptive_projected_guidance_rescale,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred

    @property
    def is_conditional(self) -> bool:
        return self._num_outputs_prepared == 0

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    diff = pred_cond - pred_uncond
    dim = [-i for i in range(1, len(diff.shape))]
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=dim)
    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + (guidance_scale - 1) * normalized_update
    return pred
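The core of APG is the projection inside normalized_guidance: the conditional/unconditional difference is split into components parallel and orthogonal to the normalized conditional prediction, and only the parallel component is scaled by eta. A minimal, self-contained sketch of that decomposition on dummy tensors (shapes and values are illustrative only, not the diffusers API):

import torch

# Stand-ins for the conditional / unconditional noise predictions.
pred_cond = torch.randn(1, 4, 8, 8)
pred_uncond = torch.randn(1, 4, 8, 8)
diff = pred_cond - pred_uncond

# Reduce over all non-batch dimensions, as in normalized_guidance above.
dim = [-i for i in range(1, len(diff.shape))]

# Decompose diff with respect to the direction of the conditional prediction.
v1 = torch.nn.functional.normalize(pred_cond.double(), dim=dim)
v0 = diff.double()
parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
orthogonal = v0 - parallel

# eta < 1 down-weights the parallel part; guidance_scale plays its usual CFG role.
eta, guidance_scale = 0.5, 7.5
update = (orthogonal + eta * parallel).type_as(diff)
pred = pred_uncond + (guidance_scale - 1) * update
print(pred.shape)  # torch.Size([1, 4, 8, 8])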
@@ -106,12 +106,17 @@ class ClassifierFreeGuidance(BaseGuidance):
    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False
        skip_start_step = int(self._start * self._num_inference_steps)
        skip_stop_step = int(self._stop * self._num_inference_steps)
        is_within_range = skip_start_step <= self._step < skip_stop_step

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
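A quick numeric illustration of the start/stop window check introduced above (the values are arbitrary): with start=0.2, stop=0.8 and 50 inference steps, guidance is only applied for steps 10 <= step < 40.

num_inference_steps = 50
skip_start_step = int(0.2 * num_inference_steps)  # 10
skip_stop_step = int(0.8 * num_inference_steps)   # 40
is_within_range = skip_start_step <= 25 < skip_stop_step  # True for step 25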
@@ -206,21 +206,31 @@ class SkipLayerGuidance(BaseGuidance):
    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False
        skip_start_step = int(self._start * self._num_inference_steps)
        skip_stop_step = int(self._stop * self._num_inference_steps)
        is_within_range = skip_start_step <= self._step < skip_stop_step

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_slg_enabled(self) -> bool:
        if not self._enabled:
            return False
        skip_start_step = int(self._start * self._num_inference_steps)
        skip_stop_step = int(self._stop * self._num_inference_steps)
        is_within_range = skip_start_step < self._step < skip_stop_step

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero

@@ -194,11 +194,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
                PipelineImageInput,
                required=True,
                description="The image(s) to be used as ip adapter"
            ),
            InputParam(
                "guidance_scale",
                default=5.0,
            ),
        )
    ]

@@ -236,11 +232,10 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):

    # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
        self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
    ):
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

@@ -259,21 +254,18 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
                negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
            single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)
@@ -323,6 +315,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
            ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("tokenizer_2", CLIPTokenizer),
            ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
        ]

    @property
@@ -337,7 +330,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
            InputParam("negative_prompt"),
            InputParam("negative_prompt_2"),
            InputParam("cross_attention_kwargs"),
            InputParam("guidance_scale", default=5.0),
            InputParam("clip_skip"),
        ]

@@ -601,10 +593,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
        data = self.get_block_state(state)
        self.check_inputs(pipeline, data)

        data.do_classifier_free_guidance = data.guidance_scale > 1.0
        data.do_classifier_free_guidance = pipeline.guider.num_conditions > 1
        data.device = pipeline._execution_device

        # Encode input prompt
        data.text_encoder_lora_scale = (
            data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None
@@ -1750,7 +1741,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
            InputParam("crops_coords_top_left", default=(0, 0)),
            InputParam("negative_crops_coords_top_left", default=(0, 0)),
            InputParam("num_images_per_prompt", default=1),
            InputParam("guidance_scale", required=True),
            InputParam("aesthetic_score", default=6.0),
            InputParam("negative_aesthetic_score", default=2.0),
        ]
@@ -1897,7 +1887,8 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
            and pipeline.unet is not None
            and pipeline.unet.config.time_cond_proj_dim is not None
        ):
            data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
            data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
            data.timestep_cond = self.get_guidance_scale_embedding(
                data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim
            ).to(device=data.device, dtype=data.latents.dtype)
@@ -1925,7 +1916,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
            InputParam("crops_coords_top_left", default=(0, 0)),
            InputParam("negative_crops_coords_top_left", default=(0, 0)),
            InputParam("num_images_per_prompt", default=1),
            InputParam("guidance_scale", default=5.0),
        ]

    @property
@@ -2051,7 +2041,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
            and pipeline.unet is not None
            and pipeline.unet.config.time_cond_proj_dim is not None
        ):
            data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
            data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt)
            data.timestep_cond = self.get_guidance_scale_embedding(
                data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim
            ).to(device=data.device, dtype=data.latents.dtype)
@@ -2234,6 +2225,10 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):

        data.num_channels_unet = pipeline.unet.config.in_channels
        data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
        if data.disable_guidance:
            pipeline.guider.force_disable()
        else:
            pipeline.guider.force_enable()

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -2354,7 +2349,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
            InputParam("cross_attention_kwargs"),
            InputParam("generator"),
            InputParam("eta", default=0.0),
            InputParam("guider_kwargs"),
        ]

    @property
@@ -2627,9 +2621,8 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
        data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
        if data.disable_guidance:
            pipeline.guider.force_disable()

        # (3) Prepare conditional inputs for controlnet using the guider
        data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
        else:
            pipeline.guider.force_enable()

        # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -2764,7 +2757,6 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
            ComponentSpec("controlnet", ControlNetUnionModel),
            ComponentSpec("scheduler", KarrasDiffusionSchedulers),
            ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
            ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()),
            ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
        ]
@@ -2781,12 +2773,9 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
            InputParam("controlnet_conditioning_scale", default=1.0),
            InputParam("guess_mode", default=False),
            InputParam("num_images_per_prompt", default=1),
            InputParam("guidance_scale", default=5.0),
            InputParam("guidance_rescale", default=0.0),
            InputParam("cross_attention_kwargs"),
            InputParam("generator"),
            InputParam("eta", default=0.0),
            InputParam("guider_kwargs")
        ]

    @property
@@ -3029,7 +3018,6 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                crops_coords=data.crops_coords,
            )
            data.height, data.width = data.control_image[idx].shape[-2:]

        # (1.6)
        # controlnet_keep
@@ -3043,80 +3031,48 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
        # (2) Prepare conditional inputs for unet using the guider
        # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
        data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
        data.guider_kwargs = data.guider_kwargs or {}
        data.guider_kwargs = {
            **data.guider_kwargs,
            "disable_guidance": data.disable_guidance,
            "guidance_scale": data.guidance_scale,
            "guidance_rescale": data.guidance_rescale,
            "batch_size": data.batch_size * data.num_images_per_prompt,
        }
        pipeline.guider.set_guider(pipeline, data.guider_kwargs)
        data.prompt_embeds = pipeline.guider.prepare_input(
            data.prompt_embeds,
            data.negative_prompt_embeds,
        )
        data.add_time_ids = pipeline.guider.prepare_input(
            data.add_time_ids,
            data.negative_add_time_ids,
        )
        data.pooled_prompt_embeds = pipeline.guider.prepare_input(
            data.pooled_prompt_embeds,
            data.negative_pooled_prompt_embeds,
        )

        if data.num_channels_unet == 9:
            data.mask = pipeline.guider.prepare_input(data.mask, data.mask)
            data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents)

        data.added_cond_kwargs = {
            "text_embeds": data.pooled_prompt_embeds,
            "time_ids": data.add_time_ids,
        }

        if data.ip_adapter_embeds is not None:
            data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
            data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds
        if data.disable_guidance:
            pipeline.guider.force_disable()
        else:
            pipeline.guider.force_enable()

        # (3) Prepare conditional inputs for controlnet using the guider
        data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
        data.controlnet_guider_kwargs = data.guider_kwargs or {}
        data.controlnet_guider_kwargs = {
            **data.controlnet_guider_kwargs,
            "disable_guidance": data.controlnet_disable_guidance,
            "guidance_scale": data.guidance_scale,
            "guidance_rescale": data.guidance_rescale,
            "batch_size": data.batch_size * data.num_images_per_prompt,
        }
        pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs)
        data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
        data.controlnet_added_cond_kwargs = {
            "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
            "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
        }
        for idx, _ in enumerate(data.control_image):
            data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx])
        # data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
        # data.controlnet_added_cond_kwargs = {
        #     "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
        #     "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
        # }

        data.control_type = (
            data.control_type.reshape(1, -1)
            .to(data.device, dtype=data.prompt_embeds.dtype)
        )
        data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype)
        repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0]
        data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0)
        data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type)

        # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
        data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)


        with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
            for i, t in enumerate(data.timesteps):
                # prepare latents for unet using the guider
                data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)

                # prepare latents for controlnet using the guider
                data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
                (
                    latents,
                    prompt_embeds,
                    add_time_ids,
                    pooled_prompt_embeds,
                    mask,
                    masked_image_latents,
                    ip_adapter_embeds,
                ) = pipeline.guider.prepare_inputs(
                    pipeline.unet,
                    data.latents,
                    (data.prompt_embeds, data.negative_prompt_embeds),
                    (data.add_time_ids, data.negative_add_time_ids),
                    (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds),
                    data.mask,
                    data.masked_image_latents,
                    (data.ip_adapter_embeds, data.negative_ip_adapter_embeds),
                )

                if isinstance(data.controlnet_keep[i], list):
                    data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])]
@@ -3126,48 +3082,72 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                    data.controlnet_cond_scale = data.controlnet_cond_scale[0]
                    data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i]

                data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
                    pipeline.scheduler.scale_model_input(data.control_model_input, t),
                    t,
                    encoder_hidden_states=data.controlnet_prompt_embeds,
                    controlnet_cond=data.control_image,
                    control_type=data.control_type,
                    control_type_idx=data.control_mode,
                    conditioning_scale=data.cond_scale,
                    guess_mode=data.guess_mode,
                    added_cond_kwargs=data.controlnet_added_cond_kwargs,
                    return_dict=False,
                )
                for batch_index, (
                    latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i
                ) in enumerate(zip(
                    latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
                )):
                    latents_i = pipeline.scheduler.scale_model_input(latents_i, t)

                    # Prepare for inpainting
                    if data.num_channels_unet == 9:
                        latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)

                    # Prepare additional conditionings
                    data.added_cond_kwargs = {
                        "text_embeds": pooled_prompt_embeds_i,
                        "time_ids": add_time_ids_i,
                    }
                    if ip_adapter_embeds_i is not None:
                        data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i

                    # Prepare controlnet additional conditionings
                    data.controlnet_added_cond_kwargs = {
                        "text_embeds": pooled_prompt_embeds_i,
                        "time_ids": add_time_ids_i,
                    }

                    if pipeline.guider.is_conditional or not data.guess_mode:
                        data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
                            latents_i,
                            t,
                            encoder_hidden_states=prompt_embeds_i,
                            controlnet_cond=data.control_image,
                            control_type=data.control_type,
                            control_type_idx=data.control_mode,
                            conditioning_scale=data.cond_scale,
                            guess_mode=data.guess_mode,
                            added_cond_kwargs=data.controlnet_added_cond_kwargs,
                            return_dict=False,
                        )
                    elif pipeline.guider.is_unconditional and data.guess_mode:
                        data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples]
                        data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample)

                # when we apply guidance for unet, but not for controlnet:
                # add 0 to the unconditional batch
                data.down_block_res_samples = pipeline.guider.prepare_input(
                    data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples]
                )
                data.mid_block_res_sample = pipeline.guider.prepare_input(
                    data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample)
                )
                    if data.num_channels_unet == 9:
                        latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)

                data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)
                if data.num_channels_unet == 9:
                    data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1)
                    data.noise_pred = pipeline.unet(
                        latents_i,
                        t,
                        encoder_hidden_states=prompt_embeds_i,
                        timestep_cond=data.timestep_cond,
                        cross_attention_kwargs=data.cross_attention_kwargs,
                        added_cond_kwargs=data.added_cond_kwargs,
                        down_block_additional_residuals=data.down_block_res_samples,
                        mid_block_additional_residual=data.mid_block_res_sample,
                        return_dict=False,
                    )[0]
                    data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)

                # Perform guidance
                outputs = pipeline.guider.outputs
                data.noise_pred = pipeline.guider(**outputs)

                data.noise_pred = pipeline.unet(
                    data.latent_model_input,
                    t,
                    encoder_hidden_states=data.prompt_embeds,
                    timestep_cond=data.timestep_cond,
                    cross_attention_kwargs=data.cross_attention_kwargs,
                    added_cond_kwargs=data.added_cond_kwargs,
                    down_block_additional_residuals=data.down_block_res_samples,
                    mid_block_additional_residual=data.mid_block_res_sample,
                    return_dict=False,
                )[0]
                # perform guidance
                data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents)
                # compute the previous noisy sample x_t -> x_t-1
                # Perform scheduler step using the predicted output
                data.latents_dtype = data.latents.dtype
                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]

                if data.latents.dtype != data.latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
@@ -3180,14 +3160,10 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                    data.init_latents_proper = pipeline.scheduler.add_noise(
                        data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep])
                    )

                    data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents

                if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                    progress_bar.update()

        pipeline.guider.reset_guider(pipeline)
        pipeline.controlnet_guider.reset_guider(pipeline)

        self.add_block_state(state, data)
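The denoise loop above replaces the old concatenated cond/uncond batch with the new guider flow: prepare_inputs expands each input into one entry per condition, the denoiser is called once per entry, each prediction is registered with prepare_outputs, and calling the guider combines them. A minimal, self-contained sketch of that flow (ToyGuider and toy_denoiser below are hypothetical stand-ins, not the diffusers API):

import torch

def toy_denoiser(latents, prompt_embeds):
    # Stand-in for pipeline.unet: mixes its inputs deterministically.
    return latents + prompt_embeds.mean()

class ToyGuider:
    # Mimics the prepare_inputs / prepare_outputs / __call__ pattern for plain CFG.
    def __init__(self, guidance_scale=5.0):
        self.guidance_scale = guidance_scale
        self.preds = {}

    def prepare_inputs(self, latents, prompt_embeds_pair):
        cond, uncond = prompt_embeds_pair
        return [latents, latents], [cond, uncond]

    def prepare_outputs(self, pred):
        key = "pred_cond" if "pred_cond" not in self.preds else "pred_uncond"
        self.preds[key] = pred

    def __call__(self):
        pred_cond, pred_uncond = self.preds["pred_cond"], self.preds["pred_uncond"]
        return pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)

latents = torch.randn(1, 4, 8, 8)
cond_embeds, uncond_embeds = torch.randn(1, 77, 16), torch.zeros(1, 77, 16)

guider = ToyGuider()
latents_list, embeds_list = guider.prepare_inputs(latents, (cond_embeds, uncond_embeds))
for latents_i, embeds_i in zip(latents_list, embeds_list):
    guider.prepare_outputs(toy_denoiser(latents_i, embeds_i))
noise_pred = guider()  # combined (guided) noise prediction
print(noise_pred.shape)  # torch.Size([1, 4, 8, 8])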