mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
seperate controlnet step into input + denoise
This commit is contained in:
@@ -2395,27 +2395,20 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 7.5}),
|
||||
default_creation_method="from_config"),
|
||||
ComponentSpec("scheduler", EulerDiscreteScheduler),
|
||||
ComponentSpec("unet", UNet2DConditionModel),
|
||||
ComponentSpec("controlnet", ControlNetModel),
|
||||
ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
|
||||
return "step that prepare inputs for controlnet"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
@@ -2426,9 +2419,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("cross_attention_kwargs"),
|
||||
InputParam("generator"),
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -2452,12 +2442,246 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
|
||||
),
|
||||
InputParam(
|
||||
"crops_coords",
|
||||
type_hint=Optional[Tuple[int]],
|
||||
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step."
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("control_image", type_hint=torch.Tensor, description="The processed control image"),
|
||||
OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"),
|
||||
OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"),
|
||||
OutputParam("controlnet_conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"),
|
||||
OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"),
|
||||
OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
|
||||
]
|
||||
|
||||
|
||||
|
||||
# Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
|
||||
# 1. return image without apply any guidance
|
||||
# 2. add crops_coords and resize_mode to preprocess()
|
||||
@staticmethod
|
||||
def prepare_control_image(
|
||||
components,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
crops_coords=None,
|
||||
):
|
||||
if crops_coords is not None:
|
||||
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
|
||||
else:
|
||||
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
return image
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
|
||||
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# (1) prepare controlnet inputs
|
||||
block_state.device = components._execution_device
|
||||
block_state.height, block_state.width = block_state.latents.shape[-2:]
|
||||
block_state.height = block_state.height * components.vae_scale_factor
|
||||
block_state.width = block_state.width * components.vae_scale_factor
|
||||
|
||||
controlnet = unwrap_module(components.controlnet)
|
||||
|
||||
# (1.1)
|
||||
# control_guidance_start/control_guidance_end (align format)
|
||||
if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list):
|
||||
block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start]
|
||||
elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list):
|
||||
block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end]
|
||||
elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
||||
block_state.control_guidance_start, block_state.control_guidance_end = (
|
||||
mult * [block_state.control_guidance_start],
|
||||
mult * [block_state.control_guidance_end],
|
||||
)
|
||||
|
||||
# (1.2)
|
||||
# controlnet_conditioning_scale (align format)
|
||||
if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float):
|
||||
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
|
||||
# (1.3)
|
||||
# global_pool_conditions
|
||||
block_state.global_pool_conditions = (
|
||||
controlnet.config.global_pool_conditions
|
||||
if isinstance(controlnet, ControlNetModel)
|
||||
else controlnet.nets[0].config.global_pool_conditions
|
||||
)
|
||||
# (1.4)
|
||||
# guess_mode
|
||||
block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
|
||||
|
||||
# (1.5)
|
||||
# control_image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
block_state.control_image = self.prepare_control_image(
|
||||
components,
|
||||
image=block_state.control_image,
|
||||
width=block_state.width,
|
||||
height=block_state.height,
|
||||
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
device=block_state.device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=block_state.crops_coords,
|
||||
)
|
||||
elif isinstance(controlnet, MultiControlNetModel):
|
||||
control_images = []
|
||||
|
||||
for control_image_ in block_state.control_image:
|
||||
control_image = self.prepare_control_image(
|
||||
components,
|
||||
image=control_image_,
|
||||
width=block_state.width,
|
||||
height=block_state.height,
|
||||
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
device=block_state.device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=block_state.crops_coords,
|
||||
)
|
||||
|
||||
control_images.append(control_image)
|
||||
|
||||
block_state.control_image = control_images
|
||||
else:
|
||||
assert False
|
||||
|
||||
# (1.6)
|
||||
# controlnet_keep
|
||||
block_state.controlnet_keep = []
|
||||
for i in range(len(block_state.timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
|
||||
for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
|
||||
]
|
||||
block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
|
||||
|
||||
self.add_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 7.5}),
|
||||
default_creation_method="from_config"),
|
||||
ComponentSpec("scheduler", EulerDiscreteScheduler),
|
||||
ComponentSpec("unet", UNet2DConditionModel),
|
||||
ComponentSpec("controlnet", ControlNetModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
|
||||
),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("cross_attention_kwargs"),
|
||||
InputParam("generator"),
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"control_image",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
|
||||
),
|
||||
InputParam(
|
||||
"control_guidance_start",
|
||||
required=True,
|
||||
type_hint=float,
|
||||
description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
|
||||
),
|
||||
InputParam(
|
||||
"control_guidance_end",
|
||||
required=True,
|
||||
type_hint=float,
|
||||
description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
|
||||
),
|
||||
InputParam(
|
||||
"controlnet_conditioning_scale",
|
||||
required=True,
|
||||
type_hint=float,
|
||||
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
|
||||
),
|
||||
InputParam(
|
||||
"guess_mode",
|
||||
required=True,
|
||||
type_hint=bool,
|
||||
description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
|
||||
),
|
||||
InputParam(
|
||||
"controlnet_keep",
|
||||
required=True,
|
||||
type_hint=List[float],
|
||||
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
|
||||
),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
|
||||
),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
|
||||
),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
|
||||
),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -2557,36 +2781,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
" `components.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
# Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
|
||||
# 1. return image without apply any guidance
|
||||
# 2. add crops_coords and resize_mode to preprocess()
|
||||
@staticmethod
|
||||
def prepare_control_image(
|
||||
components,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
crops_coords=None,
|
||||
):
|
||||
if crops_coords is not None:
|
||||
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
|
||||
else:
|
||||
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
|
||||
@staticmethod
|
||||
@@ -2613,103 +2807,19 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
block_state.num_channels_unet = components.unet.config.in_channels
|
||||
|
||||
# (1) prepare controlnet inputs
|
||||
block_state.device = components._execution_device
|
||||
block_state.height, block_state.width = block_state.latents.shape[-2:]
|
||||
block_state.height = block_state.height * components.vae_scale_factor
|
||||
block_state.width = block_state.width * components.vae_scale_factor
|
||||
|
||||
controlnet = unwrap_module(components.controlnet)
|
||||
|
||||
# (1.1)
|
||||
# control_guidance_start/control_guidance_end (align format)
|
||||
if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list):
|
||||
block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start]
|
||||
elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list):
|
||||
block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end]
|
||||
elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
||||
block_state.control_guidance_start, block_state.control_guidance_end = (
|
||||
mult * [block_state.control_guidance_start],
|
||||
mult * [block_state.control_guidance_end],
|
||||
)
|
||||
|
||||
# (1.2)
|
||||
# controlnet_conditioning_scale (align format)
|
||||
if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float):
|
||||
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
|
||||
# (1.3)
|
||||
# global_pool_conditions
|
||||
block_state.global_pool_conditions = (
|
||||
controlnet.config.global_pool_conditions
|
||||
if isinstance(controlnet, ControlNetModel)
|
||||
else controlnet.nets[0].config.global_pool_conditions
|
||||
)
|
||||
# (1.4)
|
||||
# guess_mode
|
||||
block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
|
||||
|
||||
# (1.5)
|
||||
# control_image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
block_state.control_image = self.prepare_control_image(
|
||||
components,
|
||||
image=block_state.control_image,
|
||||
width=block_state.width,
|
||||
height=block_state.height,
|
||||
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
device=block_state.device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=block_state.crops_coords,
|
||||
)
|
||||
elif isinstance(controlnet, MultiControlNetModel):
|
||||
control_images = []
|
||||
|
||||
for control_image_ in block_state.control_image:
|
||||
control_image = self.prepare_control_image(
|
||||
components,
|
||||
image=control_image_,
|
||||
width=block_state.width,
|
||||
height=block_state.height,
|
||||
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
device=block_state.device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=block_state.crops_coords,
|
||||
)
|
||||
|
||||
control_images.append(control_image)
|
||||
|
||||
block_state.control_image = control_images
|
||||
else:
|
||||
assert False
|
||||
|
||||
# (1.6)
|
||||
# controlnet_keep
|
||||
block_state.controlnet_keep = []
|
||||
for i in range(len(block_state.timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
|
||||
for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
|
||||
]
|
||||
block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# (2) Prepare conditional inputs for unet using the guider
|
||||
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta)
|
||||
block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
|
||||
|
||||
# (1) setup guider
|
||||
# disable for LCMs
|
||||
block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
|
||||
if block_state.disable_guidance:
|
||||
components.guider.disable()
|
||||
else:
|
||||
components.guider.enable()
|
||||
|
||||
# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta)
|
||||
block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
|
||||
|
||||
components.guider.set_input_fields(
|
||||
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
|
||||
add_time_ids=("add_time_ids", "negative_add_time_ids"),
|
||||
@@ -2720,11 +2830,16 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
# (5) Denoise loop
|
||||
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(block_state.timesteps):
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_data = components.guider.prepare_inputs(block_state)
|
||||
|
||||
# prepare latent input for unet
|
||||
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
|
||||
# adjust latent input for inpainting
|
||||
block_state.num_channels_unet = components.unet.config.in_channels
|
||||
if block_state.num_channels_unet == 9:
|
||||
block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1)
|
||||
|
||||
|
||||
# cond_scale (controlnet input)
|
||||
if isinstance(block_state.controlnet_keep[i], list):
|
||||
block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])]
|
||||
else:
|
||||
@@ -2733,62 +2848,69 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0]
|
||||
block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i]
|
||||
|
||||
for batch in guider_data:
|
||||
# default controlnet output/unet input for guess mode + conditional path
|
||||
block_state.down_block_res_samples_zeros = None
|
||||
block_state.mid_block_res_sample_zeros = None
|
||||
|
||||
# guided denoiser step
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(block_state)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.unet)
|
||||
|
||||
# Prepare additional conditionings
|
||||
batch.added_cond_kwargs = {
|
||||
"text_embeds": batch.pooled_prompt_embeds,
|
||||
"time_ids": batch.add_time_ids,
|
||||
guider_state_batch.added_cond_kwargs = {
|
||||
"text_embeds": guider_state_batch.pooled_prompt_embeds,
|
||||
"time_ids": guider_state_batch.add_time_ids,
|
||||
}
|
||||
if batch.ip_adapter_embeds is not None:
|
||||
batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds
|
||||
if guider_state_batch.ip_adapter_embeds is not None:
|
||||
guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds
|
||||
|
||||
# Prepare controlnet additional conditionings
|
||||
batch.controlnet_added_cond_kwargs = {
|
||||
"text_embeds": batch.pooled_prompt_embeds,
|
||||
"time_ids": batch.add_time_ids,
|
||||
guider_state_batch.controlnet_added_cond_kwargs = {
|
||||
"text_embeds": guider_state_batch.pooled_prompt_embeds,
|
||||
"time_ids": guider_state_batch.add_time_ids,
|
||||
}
|
||||
|
||||
# Will always be run atleast once with every guider
|
||||
if components.guider.is_conditional or not block_state.guess_mode:
|
||||
block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet(
|
||||
if block_state.guess_mode and not components.guider.is_conditional:
|
||||
# guider always run uncond batch first, so these tensors should be set already
|
||||
guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros
|
||||
guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros
|
||||
else:
|
||||
guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet(
|
||||
block_state.scaled_latents,
|
||||
t,
|
||||
encoder_hidden_states=batch.prompt_embeds,
|
||||
encoder_hidden_states=guider_state_batch.prompt_embeds,
|
||||
controlnet_cond=block_state.control_image,
|
||||
conditioning_scale=block_state.cond_scale,
|
||||
guess_mode=block_state.guess_mode,
|
||||
added_cond_kwargs=batch.controlnet_added_cond_kwargs,
|
||||
added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
batch.down_block_res_samples = block_state.down_block_res_samples
|
||||
batch.mid_block_res_sample = block_state.mid_block_res_sample
|
||||
if block_state.down_block_res_samples_zeros is None:
|
||||
block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples]
|
||||
if block_state.mid_block_res_sample_zeros is None:
|
||||
block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample)
|
||||
|
||||
if components.guider.is_unconditional and block_state.guess_mode:
|
||||
batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples]
|
||||
batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample)
|
||||
|
||||
# Prepare for inpainting
|
||||
if block_state.num_channels_unet == 9:
|
||||
block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1)
|
||||
|
||||
batch.noise_pred = components.unet(
|
||||
|
||||
guider_state_batch.noise_pred = components.unet(
|
||||
block_state.scaled_latents,
|
||||
t,
|
||||
encoder_hidden_states=batch.prompt_embeds,
|
||||
encoder_hidden_states=guider_state_batch.prompt_embeds,
|
||||
timestep_cond=block_state.timestep_cond,
|
||||
cross_attention_kwargs=block_state.cross_attention_kwargs,
|
||||
added_cond_kwargs=batch.added_cond_kwargs,
|
||||
down_block_additional_residuals=batch.down_block_res_samples,
|
||||
mid_block_additional_residual=batch.mid_block_res_sample,
|
||||
added_cond_kwargs=guider_state_batch.added_cond_kwargs,
|
||||
down_block_additional_residuals=guider_state_batch.down_block_res_samples,
|
||||
mid_block_additional_residual=guider_state_batch.mid_block_res_sample,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
components.guider.cleanup_models(components.unet)
|
||||
|
||||
# Perform guidance
|
||||
block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data)
|
||||
block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state)
|
||||
|
||||
# Perform scheduler step using the predicted output
|
||||
block_state.latents_dtype = block_state.latents.dtype
|
||||
@@ -2799,6 +2921,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
block_state.latents = block_state.latents.to(block_state.latents_dtype)
|
||||
|
||||
# adjust latent for inpainting
|
||||
if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None:
|
||||
block_state.init_latents_proper = block_state.image_latents
|
||||
if i < len(block_state.timesteps) - 1:
|
||||
@@ -3463,6 +3586,16 @@ class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
" - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \
|
||||
" - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning"
|
||||
|
||||
class StableDiffusionXLControlNetStep(SequentialPipelineBlocks):
|
||||
block_classes = [StableDiffusionXLControlNetInputStep, StableDiffusionXLControlNetDenoiseStep]
|
||||
block_names = ["prepare_input", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Controlnet step that denoise the latents.\n" + \
|
||||
"This is a sequential pipeline blocks:\n" + \
|
||||
" - `StableDiffusionXLControlNetInputStep` is used to prepare the inputs for the denoise step.\n" + \
|
||||
" - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents."
|
||||
|
||||
class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep]
|
||||
@@ -3477,10 +3610,9 @@ class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
" - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \
|
||||
" - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided."
|
||||
|
||||
|
||||
# Denoise
|
||||
class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep]
|
||||
block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep]
|
||||
block_names = ["controlnet_union", "controlnet", "unet"]
|
||||
block_trigger_inputs = ["control_mode", "control_image", None]
|
||||
|
||||
@@ -3489,7 +3621,7 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
|
||||
return "Denoise step that denoise the latents.\n" + \
|
||||
"This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \
|
||||
" - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \
|
||||
" - `StableDiffusionXLControlNetDenoiseStep` (controlnet) is used when `control_image` is provided.\n" + \
|
||||
" - `StableDiffusionXLControlStep` (controlnet) is used when `control_image` is provided.\n" + \
|
||||
" - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided."
|
||||
|
||||
# After denoise
|
||||
@@ -3597,7 +3729,7 @@ INPAINT_BLOCKS = OrderedDict([
|
||||
])
|
||||
|
||||
CONTROLNET_BLOCKS = OrderedDict([
|
||||
("denoise", StableDiffusionXLControlNetDenoiseStep),
|
||||
("denoise", StableDiffusionXLControlNetStep),
|
||||
])
|
||||
|
||||
CONTROLNET_UNION_BLOCKS = OrderedDict([
|
||||
|
||||
Reference in New Issue
Block a user