
support sdxl controlnet

Aryan
2025-04-14 14:04:04 +02:00
parent 0c4c1a8430
commit 9da8a9d1d5
4 changed files with 119 additions and 120 deletions
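
This commit moves the SDXL ControlNet denoise step off the old `set_guider`/`prepare_input`/`apply_guidance` flow and onto the `BaseGuidance` interface (`set_state`, `prepare_inputs`, `prepare_outputs`, `__call__`). A minimal sketch of the new per-step call pattern, with `guider`, `unet`, and the tensors as placeholders for the objects the pipeline supplies (names follow the diff below, not a stable public API):

# Sketch only: how a BaseGuidance subclass is driven inside the denoise loop.
for i, t in enumerate(timesteps):
    guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)

    # Each (cond, uncond) pair fans out into one entry per condition; bare tensors repeat.
    latents_list, prompt_embeds_list = guider.prepare_inputs(
        unet, latents, (prompt_embeds, negative_prompt_embeds)
    )

    for latents_i, prompt_embeds_i in zip(latents_list, prompt_embeds_list):
        noise_pred = unet(latents_i, t, encoder_hidden_states=prompt_embeds_i, return_dict=False)[0]
        guider.prepare_outputs(unet, noise_pred)  # register this condition's prediction

    # With every condition registered, combine the predictions (e.g. CFG).
    noise_pred = guider(**guider.outputs)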

View File

@@ -92,6 +92,10 @@ class ClassifierFreeGuidance(BaseGuidance):
return pred
+@property
+def is_conditional(self) -> bool:
+return self._num_outputs_prepared == 0
@property
def num_conditions(self) -> int:
num_conditions = 1
@@ -100,6 +104,8 @@ class ClassifierFreeGuidance(BaseGuidance):
return num_conditions
def _is_cfg_enabled(self) -> bool:
+if not self._enabled:
+return False
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
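
For intuition, the start/stop fractions above map onto step indices as follows; a worked example with illustrative values, not part of the diff:

# CFG active only for the middle 80% of a 50-step schedule.
num_inference_steps = 50
start, stop = 0.1, 0.9

skip_start_step = int(start * num_inference_steps)  # 5
skip_stop_step = int(stop * num_inference_steps)    # 45

for step in (0, 5, 44, 45):
    print(step, skip_start_step <= step < skip_stop_step)
# -> 0 False, 5 True, 44 True, 45 False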

View File

@@ -39,6 +39,7 @@ class BaseGuidance:
self._timestep: torch.LongTensor = None
self._preds: Dict[str, torch.Tensor] = {}
self._num_outputs_prepared: int = 0
+self._enabled = True
if not (0.0 <= start < 1.0):
raise ValueError(
@@ -54,6 +55,12 @@ class BaseGuidance:
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
+def force_disable(self):
+self._enabled = False
+def force_enable(self):
+self._enabled = True
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
@@ -62,10 +69,10 @@ class BaseGuidance:
self._num_outputs_prepared = 0
def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
raise NotImplementedError("GuidanceMixin::prepare_inputs must be implemented in subclasses.")
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
raise NotImplementedError("GuidanceMixin::prepare_outputs must be implemented in subclasses.")
raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.")
def __call__(self, **kwargs) -> Any:
if len(kwargs) != self.num_conditions:
@@ -75,11 +82,19 @@ class BaseGuidance:
return self.forward(**kwargs)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
+@property
+def is_conditional(self) -> bool:
+raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
+@property
+def is_unconditional(self) -> bool:
+return not self.is_conditional
@property
def num_conditions(self) -> int:
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
@property
def outputs(self) -> Dict[str, torch.Tensor]:
@@ -114,7 +129,7 @@ def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *arg
"""
Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly
prepared based on the required number of conditions. This function is used in the `prepare_inputs` method of the
-`GuidanceMixin` class.
+`BaseGuidance` class.
Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements:
- The first element is the conditional input.
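
As a rough illustration of the tuple convention described here (a hypothetical stand-in, not the actual `_default_prepare_inputs` body):

import torch

def fan_out(num_conditions, *args):
    # (cond, uncond) pairs split across conditions; bare tensors repeat for each one.
    prepared = []
    for arg in args:
        if isinstance(arg, (tuple, list)):
            cond, uncond = arg
            prepared.append([cond, uncond][:num_conditions])
        else:
            prepared.append([arg] * num_conditions)
    return tuple(prepared)

latents = torch.randn(1, 4, 64, 64)
cond_emb, uncond_emb = torch.randn(1, 77, 2048), torch.randn(1, 77, 2048)
latents_list, embeds_list = fan_out(2, latents, (cond_emb, uncond_emb))
# latents_list == [latents, latents]; embeds_list == [cond_emb, uncond_emb]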

View File

@@ -189,6 +189,10 @@ class SkipLayerGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
+@property
+def is_conditional(self) -> bool:
+return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2
@property
def num_conditions(self) -> int:
@@ -200,6 +204,8 @@ class SkipLayerGuidance(BaseGuidance):
return num_conditions
def _is_cfg_enabled(self) -> bool:
+if not self._enabled:
+return False
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
@@ -211,6 +217,8 @@ class SkipLayerGuidance(BaseGuidance):
return is_within_range and not is_close
def _is_slg_enabled(self) -> bool:
+if not self._enabled:
+return False
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
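
Two details worth noting in this file: `is_conditional` treats prediction slots 0 and 2 as conditional (slot 2 being the conditional pass with layers skipped), and `_is_slg_enabled` uses a strict lower bound (`skip_start_step < self._step`) where `_is_cfg_enabled` is inclusive. A small illustration of the slot convention implied by the diff:

# Prediction order implied by SkipLayerGuidance.is_conditional:
# slot 0 -> conditional, slot 1 -> unconditional, slot 2 -> conditional with skipped layers.
for num_outputs_prepared in range(3):
    print(num_outputs_prepared, num_outputs_prepared in (0, 2))
# -> 0 True, 1 False, 2 True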

View File

@@ -2263,21 +2263,9 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
)
for batch_index, (
-latents_i,
-prompt_embeds_i,
-add_time_ids_i,
-pooled_prompt_embeds_i,
-mask_i,
-masked_image_latents_i,
-ip_adapter_embeds_i,
+latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i,
) in enumerate(zip(
-latents,
-prompt_embeds,
-add_time_ids,
-pooled_prompt_embeds,
-mask,
-masked_image_latents,
-ip_adapter_embeds
+latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
)):
latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
@@ -2285,6 +2273,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
if data.num_channels_unet == 9:
latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)
+# Prepare additional conditionings
data.added_cond_kwargs = {
"text_embeds": pooled_prompt_embeds_i,
"time_ids": add_time_ids_i,
@@ -2292,7 +2281,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
if ip_adapter_embeds_i is not None:
data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i
-# predict the noise residual
+# Predict the noise residual
data.noise_pred = pipeline.unet(
latents_i,
t,
@@ -2347,7 +2336,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetModel),
ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()),
]
@property
@@ -2363,8 +2351,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1),
InputParam("guidance_scale", default=5.0),
InputParam("guidance_rescale", default=0.0),
InputParam("cross_attention_kwargs"),
InputParam("generator"),
InputParam("eta", default=0.0),
@@ -2515,8 +2501,8 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
@@ -2524,9 +2510,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
@@ -2557,9 +2541,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
data.num_channels_unet = pipeline.unet.config.in_channels
# (1) prepare controlnet inputs
data.device = pipeline._execution_device
data.height, data.width = data.latents.shape[-2:]
data.height = data.height * pipeline.vae_scale_factor
data.width = data.width * pipeline.vae_scale_factor
@@ -2642,59 +2624,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
# (2) Prepare conditional inputs for unet using the guider
-# adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
-data.guider_kwargs = data.guider_kwargs or {}
-data.guider_kwargs = {
-**data.guider_kwargs,
-"disable_guidance": data.disable_guidance,
-"guidance_scale": data.guidance_scale,
-"guidance_rescale": data.guidance_rescale,
-"batch_size": data.batch_size * data.num_images_per_prompt,
-}
-pipeline.guider.set_guider(pipeline, data.guider_kwargs)
-data.prompt_embeds = pipeline.guider.prepare_input(
-data.prompt_embeds,
-data.negative_prompt_embeds,
-)
-data.add_time_ids = pipeline.guider.prepare_input(
-data.add_time_ids,
-data.negative_add_time_ids,
-)
-data.pooled_prompt_embeds = pipeline.guider.prepare_input(
-data.pooled_prompt_embeds,
-data.negative_pooled_prompt_embeds,
-)
-if data.num_channels_unet == 9:
-data.mask = pipeline.guider.prepare_input(data.mask, data.mask)
-data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents)
-data.added_cond_kwargs = {
-"text_embeds": data.pooled_prompt_embeds,
-"time_ids": data.add_time_ids,
-}
-if data.ip_adapter_embeds is not None:
-data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
-data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds
+if data.disable_guidance:
+pipeline.guider.force_disable()
-# (3) Prepare conditional inputs for controlnet using the guider
-data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
-data.controlnet_guider_kwargs = data.guider_kwargs or {}
-data.controlnet_guider_kwargs = {
-**data.controlnet_guider_kwargs,
-"disable_guidance": data.controlnet_disable_guidance,
-"guidance_scale": data.guidance_scale,
-"guidance_rescale": data.guidance_rescale,
-"batch_size": data.batch_size * data.num_images_per_prompt,
-}
-pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs)
-data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
-data.controlnet_added_cond_kwargs = {
-"text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
-"time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
-}
-data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image)
# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -2703,11 +2638,26 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
# (5) Denoise loop
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
for i, t in enumerate(data.timesteps):
-# prepare latents for unet using the guider
-data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
+pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
-# prepare latents for controlnet using the guider
-data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
+(
+latents,
+prompt_embeds,
+add_time_ids,
+pooled_prompt_embeds,
+mask,
+masked_image_latents,
+ip_adapter_embeds,
+) = pipeline.guider.prepare_inputs(
+pipeline.unet,
+data.latents,
+(data.prompt_embeds, data.negative_prompt_embeds),
+(data.add_time_ids, data.negative_add_time_ids),
+(data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds),
+data.mask,
+data.masked_image_latents,
+(data.ip_adapter_embeds, data.negative_ip_adapter_embeds),
+)
if isinstance(data.controlnet_keep[i], list):
data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])]
@@ -2717,51 +2667,74 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
data.controlnet_cond_scale = data.controlnet_cond_scale[0]
data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i]
-data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
-pipeline.scheduler.scale_model_input(data.control_model_input, t),
-t,
-encoder_hidden_states=data.controlnet_prompt_embeds,
-controlnet_cond=data.control_image,
-conditioning_scale=data.cond_scale,
-guess_mode=data.guess_mode,
-added_cond_kwargs=data.controlnet_added_cond_kwargs,
-return_dict=False,
-)
+for batch_index, (
+latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i
+) in enumerate(zip(
+latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
+)):
+latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
+# Prepare for inpainting
+if data.num_channels_unet == 9:
+latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)
+# Prepare additional conditionings
+data.added_cond_kwargs = {
+"text_embeds": pooled_prompt_embeds_i,
+"time_ids": add_time_ids_i,
+}
+if ip_adapter_embeds_i is not None:
+data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i
+# Prepare controlnet additional conditionings
+data.controlnet_added_cond_kwargs = {
+"text_embeds": pooled_prompt_embeds_i,
+"time_ids": add_time_ids_i,
+}
-# when we apply guidance for unet, but not for controlnet:
-# add 0 to the unconditional batch
-data.down_block_res_samples = pipeline.guider.prepare_input(
-data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples]
-)
-data.mid_block_res_sample = pipeline.guider.prepare_input(
-data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample)
-)
+if pipeline.guider.is_conditional or not data.guess_mode:
+data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
+latents_i,
+t,
+encoder_hidden_states=prompt_embeds_i,
+controlnet_cond=data.control_image,
+conditioning_scale=data.cond_scale,
+guess_mode=data.guess_mode,
+added_cond_kwargs=data.controlnet_added_cond_kwargs,
+return_dict=False,
+)
+elif pipeline.guider.is_unconditional and data.guess_mode:
+data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples]
+data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample)
-data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)
-if data.num_channels_unet == 9:
-data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1)
+if data.num_channels_unet == 9:
+latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)
-data.noise_pred = pipeline.unet(
-data.latent_model_input,
-t,
-encoder_hidden_states=data.prompt_embeds,
-timestep_cond=data.timestep_cond,
-cross_attention_kwargs=data.cross_attention_kwargs,
-added_cond_kwargs=data.added_cond_kwargs,
-down_block_additional_residuals=data.down_block_res_samples,
-mid_block_additional_residual=data.mid_block_res_sample,
-return_dict=False,
-)[0]
-# perform guidance
-data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents)
-# compute the previous noisy sample x_t -> x_t-1
+data.noise_pred = pipeline.unet(
+latents_i,
+t,
+encoder_hidden_states=prompt_embeds_i,
+timestep_cond=data.timestep_cond,
+cross_attention_kwargs=data.cross_attention_kwargs,
+added_cond_kwargs=data.added_cond_kwargs,
+down_block_additional_residuals=data.down_block_res_samples,
+mid_block_additional_residual=data.mid_block_res_sample,
+return_dict=False,
+)[0]
+data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
+# Perform guidance
+outputs = pipeline.guider.outputs
+data.noise_pred = pipeline.guider(**outputs)
+# Perform scheduler step using the predicted output
data.latents_dtype = data.latents.dtype
data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
if data.latents.dtype != data.latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
data.latents = data.latents.to(data.latents_dtype)
if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None:
data.init_latents_proper = data.image_latents
@@ -2775,9 +2748,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
progress_bar.update()
-pipeline.guider.reset_guider(pipeline)
-pipeline.controlnet_guider.reset_guider(pipeline)
self.add_block_state(state, data)
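
A behavioral note on the ControlNet loop above: in guess mode only the conditional pass runs through the ControlNet, and the unconditional pass receives zeroed residuals, replacing the old trick of padding the unconditional batch with zeros via `prepare_input`. A self-contained sketch of that branch, with hypothetical names standing in for the pipeline objects:

import torch

def residuals_for_pass(is_conditional, guess_mode, down_block_res_samples, mid_block_res_sample):
    # Conditional pass, or guidance without guess mode: keep the real ControlNet features.
    if is_conditional or not guess_mode:
        return down_block_res_samples, mid_block_res_sample
    # Unconditional pass in guess mode: skip the ControlNet and hand the UNet zeros.
    return (
        [torch.zeros_like(d) for d in down_block_res_samples],
        torch.zeros_like(mid_block_res_sample),
    )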