mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
support sdxl controlnet
@@ -92,6 +92,10 @@ class ClassifierFreeGuidance(BaseGuidance):
         return pred

+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0
+
     @property
     def num_conditions(self) -> int:
         num_conditions = 1
@@ -100,6 +104,8 @@ class ClassifierFreeGuidance(BaseGuidance):
         return num_conditions

     def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
         skip_start_step = int(self._start * self._num_inference_steps)
         skip_stop_step = int(self._stop * self._num_inference_steps)
         is_within_range = skip_start_step <= self._step < skip_stop_step
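Note: `_is_cfg_enabled` gates classifier-free guidance to a fractional window of the sampling schedule. A minimal standalone sketch of the same arithmetic (the function name here is illustrative, not part of the diff):

import torch  # not needed for the arithmetic itself, shown for consistency with the file

def is_step_in_window(step: int, num_inference_steps: int, start: float, stop: float) -> bool:
    # Convert fractional window bounds into concrete step indices.
    start_step = int(start * num_inference_steps)
    stop_step = int(stop * num_inference_steps)
    # Guidance is active on steps in [start_step, stop_step).
    return start_step <= step < stop_step

# e.g. start=0.0, stop=0.8 over 50 steps -> active on steps 0..39
assert is_step_in_window(10, 50, 0.0, 0.8)
assert not is_step_in_window(40, 50, 0.0, 0.8)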
@@ -39,6 +39,7 @@ class BaseGuidance:
         self._timestep: torch.LongTensor = None
         self._preds: Dict[str, torch.Tensor] = {}
         self._num_outputs_prepared: int = 0
+        self._enabled = True

         if not (0.0 <= start < 1.0):
             raise ValueError(
@@ -54,6 +55,12 @@ class BaseGuidance:
                 "`_input_predictions` must be a list of required prediction names for the guidance technique."
             )

+    def force_disable(self):
+        self._enabled = False
+
+    def force_enable(self):
+        self._enabled = True
+
     def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
         self._step = step
         self._num_inference_steps = num_inference_steps
@@ -62,10 +69,10 @@ class BaseGuidance:
         self._num_outputs_prepared = 0

     def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
-        raise NotImplementedError("GuidanceMixin::prepare_inputs must be implemented in subclasses.")
+        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

     def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
-        raise NotImplementedError("GuidanceMixin::prepare_outputs must be implemented in subclasses.")
+        raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.")

     def __call__(self, **kwargs) -> Any:
         if len(kwargs) != self.num_conditions:
@@ -75,11 +82,19 @@ class BaseGuidance:
         return self.forward(**kwargs)

     def forward(self, *args, **kwargs) -> Any:
-        raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
+        raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")

+    @property
+    def is_conditional(self) -> bool:
+        raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
+
+    @property
+    def is_unconditional(self) -> bool:
+        return not self.is_conditional
+
     @property
     def num_conditions(self) -> int:
-        raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
+        raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")

     @property
     def outputs(self) -> Dict[str, torch.Tensor]:
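Note: the abstract surface above asks a subclass for `forward`, `is_conditional`, and `num_conditions`. A hedged, self-contained sketch of a minimal subclass-shaped guider; `MiniGuidance` and its simplified signatures are stand-ins, not the diff's actual classes:

import torch

class MiniGuidance:
    # Simplified stand-in for a BaseGuidance subclass; the real class carries
    # more state (start/stop window, _preds, _enabled, ...).
    def __init__(self, guidance_scale: float = 7.5):
        self.guidance_scale = guidance_scale
        self._num_outputs_prepared = 0

    @property
    def num_conditions(self) -> int:
        return 2  # conditional + unconditional

    @property
    def is_conditional(self) -> bool:
        # Mirrors ClassifierFreeGuidance: the first prepared output is the conditional one.
        return self._num_outputs_prepared == 0

    @property
    def is_unconditional(self) -> bool:
        return not self.is_conditional

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor) -> torch.Tensor:
        # Standard classifier-free guidance combination.
        return pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)

pred_c, pred_u = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
out = MiniGuidance(5.0).forward(pred_c, pred_u)
assert out.shape == pred_c.shape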
@@ -114,7 +129,7 @@ def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *arg
     """
     Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly
     prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the
-    `GuidanceMixin` class.
+    `BaseGuidance` class.

     Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements:
     - The first element is the conditional input.
@@ -189,6 +189,10 @@ class SkipLayerGuidance(BaseGuidance):
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

         return pred

+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2
+
     @property
     def num_conditions(self) -> int:
@@ -200,6 +204,8 @@ class SkipLayerGuidance(BaseGuidance):
         return num_conditions

     def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
         skip_start_step = int(self._start * self._num_inference_steps)
         skip_stop_step = int(self._stop * self._num_inference_steps)
         is_within_range = skip_start_step <= self._step < skip_stop_step
@@ -211,6 +217,8 @@ class SkipLayerGuidance(BaseGuidance):
         return is_within_range and not is_close

     def _is_slg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
         skip_start_step = int(self._start * self._num_inference_steps)
         skip_stop_step = int(self._stop * self._num_inference_steps)
         is_within_range = skip_start_step < self._step < skip_stop_step
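Note: the `rescale_noise_cfg` call above applies the guidance-rescale trick (from "Common Diffusion Noise Schedules and Sample Steps are Flawed", as used in diffusers): the guided prediction's per-sample std is matched to the conditional branch's, then blended back. A sketch of that combination, with illustrative variable names:

import torch

def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.0) -> torch.Tensor:
    # Match the std of the guided prediction to the std of the conditional
    # prediction, then blend by guidance_rescale to counteract CFG over-exposure.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg

pred_uncond, pred_cond = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
pred = pred_uncond + 7.5 * (pred_cond - pred_uncond)   # plain CFG
pred = rescale_noise_cfg(pred, pred_cond, guidance_rescale=0.7)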
@@ -2263,21 +2263,9 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                 )

                 for batch_index, (
-                    latents_i,
-                    prompt_embeds_i,
-                    add_time_ids_i,
-                    pooled_prompt_embeds_i,
-                    mask_i,
-                    masked_image_latents_i,
-                    ip_adapter_embeds_i,
+                    latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i,
                 ) in enumerate(zip(
-                    latents,
-                    prompt_embeds,
-                    add_time_ids,
-                    pooled_prompt_embeds,
-                    mask,
-                    masked_image_latents,
-                    ip_adapter_embeds
+                    latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
                 )):
                     latents_i = pipeline.scheduler.scale_model_input(latents_i, t)

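Note: the loop above walks condition-batches in lockstep: `prepare_inputs` returns one list per argument, with one entry per condition (e.g. [cond, uncond]), and `zip` pairs them up. A toy sketch of that shape contract, with made-up dimensions and names:

import torch

# Toy stand-in: each "prepared input" is a list with one tensor per condition.
latents = [torch.randn(1, 4, 8, 8)] * 2                             # same latents for cond/uncond
prompt_embeds = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]  # cond, uncond

for batch_index, (latents_i, prompt_embeds_i) in enumerate(zip(latents, prompt_embeds)):
    # One denoiser forward per condition; batch_index 0 is the conditional pass.
    assert latents_i.shape == (1, 4, 8, 8)
    assert prompt_embeds_i.shape[1:] == (77, 768)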
@@ -2285,6 +2273,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                     if data.num_channels_unet == 9:
                         latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)

+                    # Prepare additional conditionings
                     data.added_cond_kwargs = {
                         "text_embeds": pooled_prompt_embeds_i,
                         "time_ids": add_time_ids_i,
@@ -2292,7 +2281,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                     if ip_adapter_embeds_i is not None:
                         data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i

-                    # predict the noise residual
+                    # Predict the noise residual
                     data.noise_pred = pipeline.unet(
                         latents_i,
                         t,
@@ -2347,7 +2336,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             ComponentSpec("unet", UNet2DConditionModel),
             ComponentSpec("controlnet", ControlNetModel),
             ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
-            ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()),
         ]

     @property
@@ -2363,8 +2351,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             InputParam("controlnet_conditioning_scale", default=1.0),
             InputParam("guess_mode", default=False),
             InputParam("num_images_per_prompt", default=1),
-            InputParam("guidance_scale", default=5.0),
-            InputParam("guidance_rescale", default=0.0),
             InputParam("cross_attention_kwargs"),
             InputParam("generator"),
             InputParam("eta", default=0.0),
@@ -2515,8 +2501,8 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
         else:
             image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

         image_batch_size = image.shape[0]

         if image_batch_size == 1:
             repeat_by = batch_size
         else:
@@ -2524,9 +2510,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             repeat_by = num_images_per_prompt

         image = image.repeat_interleave(repeat_by, dim=0)
-
         image = image.to(device=device, dtype=dtype)
-
         return image

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
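Note: `prepare_control_image` broadcasts a single control image across the batch; if the user already passed one image per prompt, it only repeats per `num_images_per_prompt`. A sketch of the repeat logic in isolation (the helper name is hypothetical; the argument names mirror the diff):

import torch

def expand_control_image(image: torch.Tensor, batch_size: int, num_images_per_prompt: int) -> torch.Tensor:
    # A single control image is tiled to the full batch; otherwise repeat per prompt.
    repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
    return image.repeat_interleave(repeat_by, dim=0)

img = torch.randn(1, 3, 64, 64)
assert expand_control_image(img, batch_size=4, num_images_per_prompt=1).shape[0] == 4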
@@ -2557,9 +2541,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         data.num_channels_unet = pipeline.unet.config.in_channels

         # (1) prepare controlnet inputs
-
         data.device = pipeline._execution_device
-
         data.height, data.width = data.latents.shape[-2:]
         data.height = data.height * pipeline.vae_scale_factor
         data.width = data.width * pipeline.vae_scale_factor
@@ -2642,59 +2624,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

-        # (2) Prepare conditional inputs for unet using the guider
-        # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
         data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
-        data.guider_kwargs = data.guider_kwargs or {}
-        data.guider_kwargs = {
-            **data.guider_kwargs,
-            "disable_guidance": data.disable_guidance,
-            "guidance_scale": data.guidance_scale,
-            "guidance_rescale": data.guidance_rescale,
-            "batch_size": data.batch_size * data.num_images_per_prompt,
-        }
-        pipeline.guider.set_guider(pipeline, data.guider_kwargs)
-        data.prompt_embeds = pipeline.guider.prepare_input(
-            data.prompt_embeds,
-            data.negative_prompt_embeds,
-        )
-        data.add_time_ids = pipeline.guider.prepare_input(
-            data.add_time_ids,
-            data.negative_add_time_ids,
-        )
-        data.pooled_prompt_embeds = pipeline.guider.prepare_input(
-            data.pooled_prompt_embeds,
-            data.negative_pooled_prompt_embeds,
-        )
-        if data.num_channels_unet == 9:
-            data.mask = pipeline.guider.prepare_input(data.mask, data.mask)
-            data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents)
-
-        data.added_cond_kwargs = {
-            "text_embeds": data.pooled_prompt_embeds,
-            "time_ids": data.add_time_ids,
-        }
-
-        if data.ip_adapter_embeds is not None:
-            data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
-            data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds
+        if data.disable_guidance:
+            pipeline.guider.force_disable()

-        # (3) Prepare conditional inputs for controlnet using the guider
-        data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
-        data.controlnet_guider_kwargs = data.guider_kwargs or {}
-        data.controlnet_guider_kwargs = {
-            **data.controlnet_guider_kwargs,
-            "disable_guidance": data.controlnet_disable_guidance,
-            "guidance_scale": data.guidance_scale,
-            "guidance_rescale": data.guidance_rescale,
-            "batch_size": data.batch_size * data.num_images_per_prompt,
-        }
-        pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs)
-        data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
-        data.controlnet_added_cond_kwargs = {
-            "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
-            "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
-        }
-        data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image)

         # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -2703,11 +2638,26 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         # (5) Denoise loop
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                # prepare latents for unet using the guider
-                data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)

-                # prepare latents for controlnet using the guider
-                data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
+                (
+                    latents,
+                    prompt_embeds,
+                    add_time_ids,
+                    pooled_prompt_embeds,
+                    mask,
+                    masked_image_latents,
+                    ip_adapter_embeds,
+                ) = pipeline.guider.prepare_inputs(
+                    pipeline.unet,
+                    data.latents,
+                    (data.prompt_embeds, data.negative_prompt_embeds),
+                    (data.add_time_ids, data.negative_add_time_ids),
+                    (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds),
+                    data.mask,
+                    data.masked_image_latents,
+                    (data.ip_adapter_embeds, data.negative_ip_adapter_embeds),
+                )

                 if isinstance(data.controlnet_keep[i], list):
                     data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])]
@@ -2717,51 +2667,74 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
                     data.controlnet_cond_scale = data.controlnet_cond_scale[0]
                 data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i]

-                data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
-                    pipeline.scheduler.scale_model_input(data.control_model_input, t),
-                    t,
-                    encoder_hidden_states=data.controlnet_prompt_embeds,
-                    controlnet_cond=data.control_image,
-                    conditioning_scale=data.cond_scale,
-                    guess_mode=data.guess_mode,
-                    added_cond_kwargs=data.controlnet_added_cond_kwargs,
-                    return_dict=False,
-                )
+                for batch_index, (
+                    latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i
+                ) in enumerate(zip(
+                    latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
+                )):
+                    latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
+
+                    # Prepare additional conditionings
+                    data.added_cond_kwargs = {
+                        "text_embeds": pooled_prompt_embeds_i,
+                        "time_ids": add_time_ids_i,
+                    }
+                    if ip_adapter_embeds_i is not None:
+                        data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i
+
+                    # Prepare controlnet additional conditionings
+                    data.controlnet_added_cond_kwargs = {
+                        "text_embeds": pooled_prompt_embeds_i,
+                        "time_ids": add_time_ids_i,
+                    }

-                # when we apply guidance for unet, but not for controlnet:
-                # add 0 to the unconditional batch
-                data.down_block_res_samples = pipeline.guider.prepare_input(
-                    data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples]
-                )
-                data.mid_block_res_sample = pipeline.guider.prepare_input(
-                    data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample)
-                )
+                    if pipeline.guider.is_conditional or not data.guess_mode:
+                        data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
+                            latents_i,
+                            t,
+                            encoder_hidden_states=prompt_embeds_i,
+                            controlnet_cond=data.control_image,
+                            conditioning_scale=data.cond_scale,
+                            guess_mode=data.guess_mode,
+                            added_cond_kwargs=data.controlnet_added_cond_kwargs,
+                            return_dict=False,
+                        )
+                    elif pipeline.guider.is_unconditional and data.guess_mode:
+                        data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples]
+                        data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample)

-                data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)
-                if data.num_channels_unet == 9:
-                    data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1)
+                    # Prepare for inpainting
+                    if data.num_channels_unet == 9:
+                        latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)

-                data.noise_pred = pipeline.unet(
-                    data.latent_model_input,
-                    t,
-                    encoder_hidden_states=data.prompt_embeds,
-                    timestep_cond=data.timestep_cond,
-                    cross_attention_kwargs=data.cross_attention_kwargs,
-                    added_cond_kwargs=data.added_cond_kwargs,
-                    down_block_additional_residuals=data.down_block_res_samples,
-                    mid_block_additional_residual=data.mid_block_res_sample,
-                    return_dict=False,
-                )[0]
-                # perform guidance
-                data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents)
-                # compute the previous noisy sample x_t -> x_t-1
+                    data.noise_pred = pipeline.unet(
+                        latents_i,
+                        t,
+                        encoder_hidden_states=prompt_embeds_i,
+                        timestep_cond=data.timestep_cond,
+                        cross_attention_kwargs=data.cross_attention_kwargs,
+                        added_cond_kwargs=data.added_cond_kwargs,
+                        down_block_additional_residuals=data.down_block_res_samples,
+                        mid_block_additional_residual=data.mid_block_res_sample,
+                        return_dict=False,
+                    )[0]
+                    data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
+
+                # Perform guidance
+                outputs = pipeline.guider.outputs
+                data.noise_pred = pipeline.guider(**outputs)
+
+                # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
                 data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]

                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         data.latents = data.latents.to(data.latents_dtype)

                 if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None:
                     data.init_latents_proper = data.image_latents
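Note: taken together, the new denoise loop replaces the old `set_guider`/`prepare_input`/`apply_guidance` flow with a per-step protocol: `set_state`, then `prepare_inputs`, then one denoiser forward per condition collected via `prepare_outputs`, then the guider call on the collected `outputs`. A condensed, hypothetical sketch of that control flow; `TwoPassCFG` and the lambda denoiser are stand-ins, not the diff's actual classes:

import torch

class TwoPassCFG:
    """Toy guider following the protocol in this diff: prepare_inputs ->
    per-condition forwards -> prepare_outputs -> __call__ on collected outputs."""

    def __init__(self, scale: float):
        self.scale = scale
        self._preds: list = []

    def set_state(self, step: int, num_inference_steps: int, timestep) -> None:
        self._step, self._num_inference_steps, self._timestep = step, num_inference_steps, timestep
        self._preds = []

    def prepare_inputs(self, denoiser, latents, cond_uncond_pair):
        # One entry per condition: [conditional, unconditional].
        return [latents, latents], list(cond_uncond_pair)

    def prepare_outputs(self, denoiser, pred: torch.Tensor) -> None:
        self._preds.append(pred)

    @property
    def outputs(self) -> dict:
        return {"pred_cond": self._preds[0], "pred_uncond": self._preds[1]}

    def __call__(self, pred_cond, pred_uncond):
        return pred_uncond + self.scale * (pred_cond - pred_uncond)

denoiser = lambda latents_i, cond_i: latents_i * 0.9 + cond_i.mean()  # stand-in UNet
guider = TwoPassCFG(scale=7.5)
latents = torch.randn(1, 4, 8, 8)
guider.set_state(step=0, num_inference_steps=30, timestep=999)
lat_list, cond_list = guider.prepare_inputs(denoiser, latents, (torch.randn(1), torch.randn(1)))
for latents_i, cond_i in zip(lat_list, cond_list):
    guider.prepare_outputs(denoiser, denoiser(latents_i, cond_i))
noise_pred = guider(**guider.outputs)
assert noise_pred.shape == latents.shape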
@@ -2775,9 +2748,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):

                 if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                     progress_bar.update()

-        pipeline.guider.reset_guider(pipeline)
-        pipeline.controlnet_guider.reset_guider(pipeline)
-
         self.add_block_state(state, data)