
refactor controlnet union

Author: yiyixuxu
Date: 2025-05-04 22:17:25 +02:00
Parent: efd70b7838
Commit: 43ac1ff7e7


@@ -2613,12 +2613,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam("num_images_per_prompt", default=1),
InputParam("cross_attention_kwargs"),
InputParam("generator"),
@@ -2755,6 +2749,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
type_hint=Optional[torch.Tensor],
description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
),
]
@property
@@ -2940,25 +2940,198 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
return components, state
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("controlnet", ControlNetUnionModel),
ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"),
]
@property
def description(self) -> str:
return "step that prepares inputs for the ControlNetUnion model"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("control_image", required=True),
InputParam("control_mode", default=[0]),
InputParam("control_guidance_start", default=0.0),
InputParam("control_guidance_end", default=1.0),
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of model tensor inputs. Can be generated in input step."
),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam(
"crops_coords",
type_hint=Optional[Tuple[int]],
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step."
),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam("control_image", type_hint=List[torch.Tensor], description="The processed control images"),
OutputParam("control_mode", type_hint=List[int], description="The control mode indices"),
OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active"),
OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"),
OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"),
OutputParam("controlnet_conditioning_scale", type_hint=float, description="The controlnet conditioning scale value"),
OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"),
OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
]
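# Editor's note (not part of the original diff): every intermediate declared here reappears as a
# required intermediates_input of StableDiffusionXLControlNetUnionDenoiseStep further down, which is
# how the two blocks hand data to each other once composed in StableDiffusionXLControlNetUnionStep.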
# Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
# 1. return image without applying any guidance
# 2. add crops_coords and resize_mode to preprocess()
@staticmethod
def prepare_control_image(
components,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
crops_coords=None,
):
if crops_coords is not None:
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
return image
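# Editor's note (not part of the original diff): the caller passes batch_size * num_images_per_prompt
# as `batch_size`, so a single (1, 3, H, W) control image is expanded by repeat_interleave to
# (batch_size * num_images_per_prompt, 3, H, W); if one control image per prompt was supplied, only
# num_images_per_prompt copies are made instead.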
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
controlnet = unwrap_module(components.controlnet)
device = block_state.device or components._execution_device
dtype = block_state.dtype or components.controlnet.dtype
block_state.height, block_state.width = block_state.latents.shape[-2:]
block_state.height = block_state.height * components.vae_scale_factor
block_state.width = block_state.width * components.vae_scale_factor
# control_guidance_start/control_guidance_end (align format)
if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list):
block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start]
elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list):
block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end]
# guess_mode
block_state.global_pool_conditions = controlnet.config.global_pool_conditions
block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
if not isinstance(block_state.control_image, list):
block_state.control_image = [block_state.control_image]
if not isinstance(block_state.control_mode, list):
block_state.control_mode = [block_state.control_mode]
if len(block_state.control_image) != len(block_state.control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
# control_type
block_state.num_control_type = controlnet.config.num_control_type
block_state.control_type = [0 for _ in range(block_state.num_control_type)]
for control_idx in block_state.control_mode:
block_state.control_type[control_idx] = 1
block_state.control_type = torch.Tensor(block_state.control_type)
block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype)
repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0]
block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0)
# prepare control_image
for idx, _ in enumerate(block_state.control_image):
block_state.control_image[idx] = self.prepare_control_image(
components,
image=block_state.control_image[idx],
width=block_state.width,
height=block_state.height,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
num_images_per_prompt=block_state.num_images_per_prompt,
device=device,
dtype=dtype,
crops_coords=block_state.crops_coords,
)
block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
# controlnet_keep
block_state.controlnet_keep = []
for i in range(len(block_state.timesteps)):
block_state.controlnet_keep.append(
1.0
- float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end)
)
self.add_block_state(state, block_state)
return components, state
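# --- Editor's illustration (not part of the original diff) ---
# A minimal, self-contained sketch of the two pieces of bookkeeping this block performs: the one-hot
# `control_type` vector derived from `control_mode`, and the per-timestep `controlnet_keep` schedule
# derived from `control_guidance_start` / `control_guidance_end`. All values below are made up.
import torch

num_control_type = 6                         # assume a ControlNetUnion checkpoint with 6 control types
control_mode = [0, 3]                        # e.g. two active control types
control_type = torch.zeros(num_control_type)
control_type[control_mode] = 1.0             # -> tensor([1., 0., 0., 1., 0., 0.])

timesteps = list(range(10))
control_guidance_start, control_guidance_end = 0.0, 0.8
controlnet_keep = [
    1.0 - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
    for i in range(len(timesteps))
]
# controlnet_keep == [1.0] * 8 + [0.0] * 2: the controlnet is applied only for the first 80% of steps.
# --- end illustration ---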
class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetUnionModel),
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
ComponentSpec(
"control_image_processor",
VaeImageProcessor,
config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
default_creation_method="from_config"),
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetUnionModel),
]
@property
@@ -2967,12 +3140,6 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("control_image", required=True),
InputParam("control_guidance_start", default=0.0),
InputParam("control_guidance_end", default=1.0),
InputParam("control_mode", required=True),
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1),
InputParam("cross_attention_kwargs"),
InputParam("generator"),
@@ -2983,15 +3150,75 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
"control_image",
required=True,
type_hint=List[torch.Tensor],
description="The control images to use for conditioning. Can be generated in prepare controlnet inputs step."
),
InputParam(
"control_mode",
required=True,
type_hint=List[int],
description="The control mode indices. Can be generated in prepare controlnet inputs step."
),
InputParam(
"control_type",
required=True,
type_hint=torch.Tensor,
description="The control type tensor that specifies which control type is active. Can be generated in prepare controlnet inputs step."
),
InputParam(
"num_control_type",
required=True,
type_hint=int,
description="The number of control types available. Can be generated in prepare controlnet inputs step."
),
InputParam(
"control_guidance_start",
required=True,
type_hint=float,
description="The control guidance start value. Can be generated in prepare controlnet inputs step."
),
InputParam(
"control_guidance_end",
required=True,
type_hint=float,
description="The control guidance end value. Can be generated in prepare controlnet inputs step."
),
InputParam(
"controlnet_conditioning_scale",
required=True,
type_hint=float,
description="The controlnet conditioning scale value. Can be generated in prepare controlnet inputs step."
),
InputParam(
"guess_mode",
required=True,
type_hint=bool,
description="Whether guess mode is used. Can be generated in prepare controlnet inputs step."
),
InputParam(
"global_pool_conditions",
required=True,
type_hint=bool,
description="Whether global pool conditions are used. Can be generated in prepare controlnet inputs step."
),
InputParam(
"controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values. Can be generated in prepare controlnet inputs step."
),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
InputParam(
@@ -3045,23 +3272,23 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step."
),
InputParam(
"mask",
type_hint=Optional[torch.Tensor],
"mask",
type_hint=Optional[torch.Tensor],
description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
),
InputParam(
"masked_image_latents",
type_hint=Optional[torch.Tensor],
"masked_image_latents",
type_hint=Optional[torch.Tensor],
description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
),
InputParam(
"noise",
type_hint=Optional[torch.Tensor],
"noise",
type_hint=Optional[torch.Tensor],
description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step."
),
InputParam(
"image_latents",
type_hint=Optional[torch.Tensor],
"image_latents",
type_hint=Optional[torch.Tensor],
description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step."
),
InputParam(
@@ -3070,19 +3297,19 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
),
InputParam(
"ip_adapter_embeds",
type_hint=Optional[torch.Tensor],
"ip_adapter_embeds",
type_hint=Optional[torch.Tensor],
description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
),
InputParam(
"negative_ip_adapter_embeds",
type_hint=Optional[torch.Tensor],
"negative_ip_adapter_embeds",
type_hint=Optional[torch.Tensor],
description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@staticmethod
@@ -3105,39 +3332,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
" `components.unet` or your `mask_image` or `image` input."
)
# Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
# 1. return image without applying any guidance
# 2. add crops_coords and resize_mode to preprocess()
@staticmethod
def prepare_control_image(
components,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
crops_coords=None,
):
if crops_coords is not None:
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
@staticmethod
@@ -3164,85 +3359,20 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
self.check_inputs(components, block_state)
block_state.num_channels_unet = components.unet.config.in_channels
# (1) prepare controlnet inputs
block_state.device = components._execution_device
block_state.height, block_state.width = block_state.latents.shape[-2:]
block_state.height = block_state.height * components.vae_scale_factor
block_state.width = block_state.width * components.vae_scale_factor
controlnet = unwrap_module(components.controlnet)
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta)
block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
# (1.1)
# control guidance
if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list):
block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start]
elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list):
block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end]
# (1.2)
# global_pool_conditions & guess_mode
block_state.global_pool_conditions = controlnet.config.global_pool_conditions
block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
# (1.3)
# control_type
block_state.num_control_type = controlnet.config.num_control_type
# (1.4)
# control_type
if not isinstance(block_state.control_image, list):
block_state.control_image = [block_state.control_image]
if not isinstance(block_state.control_mode, list):
block_state.control_mode = [block_state.control_mode]
if len(block_state.control_image) != len(block_state.control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
block_state.control_type = [0 for _ in range(block_state.num_control_type)]
for control_idx in block_state.control_mode:
block_state.control_type[control_idx] = 1
block_state.control_type = torch.Tensor(block_state.control_type)
# (1.5)
# prepare control_image
for idx, _ in enumerate(block_state.control_image):
block_state.control_image[idx] = self.prepare_control_image(
components,
image=block_state.control_image[idx],
width=block_state.width,
height=block_state.height,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
num_images_per_prompt=block_state.num_images_per_prompt,
device=block_state.device,
dtype=controlnet.dtype,
crops_coords=block_state.crops_coords,
)
block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
# (1.6)
# controlnet_keep
block_state.controlnet_keep = []
for i in range(len(block_state.timesteps)):
block_state.controlnet_keep.append(
1.0
- float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end)
)
# (2) Prepare conditional inputs for unet using the guider
# adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
# Setup guider
# disable for LCMs
block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
if block_state.disable_guidance:
components.guider.disable()
else:
components.guider.enable()
block_state.control_type = block_state.control_type.reshape(1, -1).to(block_state.device, dtype=block_state.prompt_embeds.dtype)
repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0]
block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0)
# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta)
block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
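# Editor's note (not part of the original diff): prepare_extra_step_kwargs only forwards `eta` and
# `generator` when the scheduler's step() signature accepts them (eta is DDIM-specific), and
# num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0) counts the leading
# timesteps on which the progress bar should not advance for higher-order schedulers.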
@@ -3612,7 +3742,7 @@ class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
# Denoise
class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep]
block_names = ["controlnet_union", "controlnet", "unet"]
block_trigger_inputs = ["control_mode", "control_image", None]
@@ -3620,8 +3750,8 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
def description(self):
return "Denoise step that denoise the latents.\n" + \
"This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \
" - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \
" - `StableDiffusionXLControlStep` (controlnet) is used when `control_image` is provided.\n" + \
" - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \
" - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \
" - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided."
# After denoise
@@ -3733,7 +3863,7 @@ CONTROLNET_BLOCKS = OrderedDict([
])
CONTROLNET_UNION_BLOCKS = OrderedDict([
("denoise", StableDiffusionXLControlNetUnionDenoiseStep),
("denoise", StableDiffusionXLControlNetUnionStep),
])
IP_ADAPTER_BLOCKS = OrderedDict([
@@ -3865,3 +3995,15 @@ SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
SDXL_OUTPUTS_SCHEMA = {
"images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images")
}
class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetUnionDenoiseStep]
block_names = ["prepare_input", "denoise"]
@property
def description(self):
return "ControlNetUnion step that denoises the latents.\n" + \
"This is a sequential pipeline blocks:\n" + \
" - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \
" - `StableDiffusionXLControlNetUnionDenoiseStep` is used to denoise the latents using the ControlNetUnion model."