mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Addition of new callbacks to controlnets (#5812)
* add new callbacks to src/diffusers/pipelines/controlnet/pipeline_controlnet.py * update callbacks * fix repeated kwarg * update --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -130,6 +130,7 @@ class StableDiffusionControlNetPipeline(
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -485,15 +486,21 @@ class StableDiffusionControlNetPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -760,6 +767,10 @@ class StableDiffusionControlNetPipeline(
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@@ -767,6 +778,14 @@ class StableDiffusionControlNetPipeline(
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -786,14 +805,15 @@ class StableDiffusionControlNetPipeline(
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -868,6 +888,15 @@ class StableDiffusionControlNetPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -878,6 +907,23 @@ class StableDiffusionControlNetPipeline(
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
@@ -903,9 +949,12 @@ class StableDiffusionControlNetPipeline(
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -929,7 +978,7 @@ class StableDiffusionControlNetPipeline(
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
@@ -940,7 +989,7 @@ class StableDiffusionControlNetPipeline(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@@ -988,6 +1037,7 @@ class StableDiffusionControlNetPipeline(
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
@@ -1078,7 +1128,7 @@ class StableDiffusionControlNetPipeline(
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
return_dict=False,
|
||||
@@ -1087,11 +1137,21 @@ class StableDiffusionControlNetPipeline(
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
@@ -164,6 +164,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -519,15 +520,21 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -808,6 +815,29 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -829,14 +859,15 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -892,12 +923,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
@@ -915,6 +940,15 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -925,6 +959,23 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
@@ -950,8 +1001,13 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -961,10 +1017,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
@@ -978,23 +1030,23 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
@@ -1010,7 +1062,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
elif isinstance(controlnet, MultiControlNetModel):
|
||||
@@ -1025,7 +1077,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
@@ -1039,6 +1091,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
latents = self.prepare_latents(
|
||||
@@ -1068,11 +1121,11 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
if guess_mode and self.do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
@@ -1099,7 +1152,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
if guess_mode and self.do_classifier_free_guidance:
|
||||
# Infered ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
@@ -1111,20 +1164,30 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
@@ -286,6 +286,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -656,18 +657,24 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -999,6 +1006,29 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1021,14 +1051,15 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -1101,12 +1132,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
@@ -1124,6 +1149,15 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -1134,6 +1168,23 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
@@ -1161,8 +1212,13 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -1172,10 +1228,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
@@ -1189,23 +1241,23 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
@@ -1218,7 +1270,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
elif isinstance(controlnet, MultiControlNetModel):
|
||||
@@ -1233,7 +1285,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
@@ -1261,6 +1313,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
||||
is_strength_max = strength == 1.0
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
@@ -1297,7 +1350,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
@@ -1317,11 +1370,11 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
if guess_mode and self.do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
@@ -1348,7 +1401,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
if guess_mode and self.do_classifier_free_guidance:
|
||||
# Infered ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
@@ -1363,14 +1416,14 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
@@ -1379,7 +1432,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
if num_channels_unet == 4:
|
||||
init_latents_proper = image_latents
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
init_mask, _ = mask.chunk(2)
|
||||
else:
|
||||
init_mask = mask
|
||||
@@ -1392,6 +1445,16 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
@@ -34,6 +34,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_invisible_watermark_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -167,6 +168,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -555,6 +557,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
@@ -565,14 +568,20 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
|
||||
f" {type(num_inference_steps)}."
|
||||
)
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -1008,6 +1017,29 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1039,8 +1071,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
guess_mode: bool = False,
|
||||
@@ -1053,6 +1083,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
aesthetic_score: float = 6.0,
|
||||
negative_aesthetic_score: float = 2.5,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -1147,12 +1180,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -1182,6 +1209,15 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -1190,6 +1226,23 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple. `tuple. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
@@ -1237,8 +1290,13 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -1248,17 +1306,13 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
(
|
||||
@@ -1271,7 +1325,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
prompt_2=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
@@ -1279,7 +1333,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# 4. set timesteps
|
||||
@@ -1300,6 +1354,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
||||
is_strength_max = strength == 1.0
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Preprocess mask and image - resizes image and mask w.r.t height and width
|
||||
# 5.1 Prepare init image
|
||||
@@ -1316,7 +1371,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
elif isinstance(controlnet, MultiControlNetModel):
|
||||
@@ -1331,7 +1386,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
@@ -1385,7 +1440,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
@@ -1446,7 +1501,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
@@ -1483,7 +1538,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
@@ -1491,7 +1546,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
if guess_mode and self.do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
@@ -1528,7 +1583,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
if guess_mode and self.do_classifier_free_guidance:
|
||||
# Infered ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
@@ -1543,7 +1598,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
@@ -1551,11 +1606,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
@@ -1564,7 +1619,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
if num_channels_unet == 4:
|
||||
init_latents_proper = image_latents
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
init_mask, _ = mask.chunk(2)
|
||||
else:
|
||||
init_mask = mask
|
||||
@@ -1577,6 +1632,16 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
@@ -35,7 +35,14 @@ from ...models.attention_processor import (
|
||||
)
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
@@ -143,6 +150,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
# leave controlnet out on purpose because it iterates with unet
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -487,15 +495,21 @@ class StableDiffusionXLControlNetPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -825,6 +839,10 @@ class StableDiffusionXLControlNetPipeline(
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@@ -832,6 +850,14 @@ class StableDiffusionXLControlNetPipeline(
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -855,8 +881,6 @@ class StableDiffusionXLControlNetPipeline(
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
guess_mode: bool = False,
|
||||
@@ -869,6 +893,9 @@ class StableDiffusionXLControlNetPipeline(
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -937,12 +964,6 @@ class StableDiffusionXLControlNetPipeline(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
@@ -989,6 +1010,15 @@ class StableDiffusionXLControlNetPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -997,6 +1027,23 @@ class StableDiffusionXLControlNetPipeline(
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned containing the output images.
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
@@ -1026,9 +1073,12 @@ class StableDiffusionXLControlNetPipeline(
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1052,7 +1102,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
@@ -1072,7 +1122,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# 4. Prepare image
|
||||
@@ -1115,6 +1165,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
@@ -1254,7 +1305,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
@@ -1269,6 +1320,16 @@ class StableDiffusionXLControlNetPipeline(
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
@@ -37,6 +37,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
@@ -195,6 +196,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -543,6 +545,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
@@ -553,14 +556,20 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
|
||||
f" {type(num_inference_steps)}."
|
||||
)
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -951,6 +960,29 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -976,8 +1008,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
|
||||
guess_mode: bool = False,
|
||||
@@ -992,6 +1022,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
aesthetic_score: float = 6.0,
|
||||
negative_aesthetic_score: float = 2.5,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -1077,12 +1110,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -1138,6 +1165,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -1146,6 +1182,23 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
|
||||
containing the output images.
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
@@ -1177,8 +1230,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -1188,10 +1246,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
@@ -1205,7 +1259,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
@@ -1217,7 +1271,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
prompt_2,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
@@ -1225,7 +1279,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# 4. Prepare image and controlnet_conditioning_image
|
||||
@@ -1240,7 +1294,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
@@ -1256,7 +1310,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
@@ -1271,6 +1325,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
latents = self.prepare_latents(
|
||||
@@ -1328,7 +1383,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
@@ -1343,13 +1398,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
if guess_mode and self.do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
@@ -1382,7 +1437,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
if guess_mode and self.do_classifier_free_guidance:
|
||||
# Infered ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
@@ -1394,7 +1449,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
@@ -1402,13 +1457,23 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
Reference in New Issue
Block a user