From fdb05f54ef61007a8f77ad8537b292176d13fafb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Sun, 12 May 2024 23:10:29 -0400 Subject: [PATCH] Official callbacks (#7761) --- docs/source/en/using-diffusers/callback.md | 67 +++++++- src/diffusers/callbacks.py | 156 ++++++++++++++++++ .../controlnet/pipeline_controlnet.py | 18 +- .../controlnet/pipeline_controlnet_img2img.py | 18 +- .../controlnet/pipeline_controlnet_inpaint.py | 18 +- .../pipeline_controlnet_inpaint_sd_xl.py | 18 +- .../controlnet/pipeline_controlnet_sd_xl.py | 18 +- .../pipeline_controlnet_sd_xl_img2img.py | 18 +- .../controlnet_xs/pipeline_controlnet_xs.py | 18 +- .../pipeline_controlnet_xs_sd_xl.py | 18 +- .../pipeline_stable_diffusion.py | 19 ++- .../pipeline_stable_diffusion_img2img.py | 18 +- .../pipeline_stable_diffusion_inpaint.py | 18 +- ...eline_stable_diffusion_instruct_pix2pix.py | 18 +- .../pipeline_stable_diffusion_xl.py | 18 +- .../pipeline_stable_diffusion_xl_img2img.py | 18 +- .../pipeline_stable_diffusion_xl_inpaint.py | 18 +- 17 files changed, 400 insertions(+), 94 deletions(-) create mode 100644 src/diffusers/callbacks.py diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md index 3f3e8dae9f..7445513dbf 100644 --- a/docs/source/en/using-diffusers/callback.md +++ b/docs/source/en/using-diffusers/callback.md @@ -19,13 +19,74 @@ The denoising loop of a pipeline can be modified with custom defined functions u This guide will demonstrate how callbacks work by a few features you can implement with them. +## Official callbacks + +We provide a list of callbacks you can plug into an existing pipeline and modify the denoising loop. This is the current list of official callbacks: + +- `SDCFGCutoffCallback`: Disables the CFG after a certain number of steps for all SD 1.5 pipelines, including text-to-image, image-to-image, inpaint, and controlnet. +- `SDXLCFGCutoffCallback`: Disables the CFG after a certain number of steps for all SDXL pipelines, including text-to-image, image-to-image, inpaint, and controlnet. +- `IPAdapterScaleCutoffCallback`: Disables the IP Adapter after a certain number of steps for all pipelines supporting IP-Adapter. + +> [!TIP] +> If you want to add a new official callback, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) or [submit a PR](https://huggingface.co/docs/diffusers/main/en/conceptual/contribution#how-to-open-a-pr). + +To set up a callback, you need to specify the number of denoising steps after which the callback comes into effect. You can do so by using either one of these two arguments + +- `cutoff_step_ratio`: Float number with the ratio of the steps. +- `cutoff_step_index`: Integer number with the exact number of the step. + +```python +import torch + +from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline +from diffusers.callbacks import SDXLCFGCutoffCallback + + +callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4) +# can also be used with cutoff_step_index +# callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10) + +pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, + variant="fp16", +).to("cuda") +pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True) + +prompt = "a sports car at the road, best quality, high quality, high detail, 8k resolution" + +generator = torch.Generator(device="cpu").manual_seed(2628670641) + +out = pipeline( + prompt=prompt, + negative_prompt="", + guidance_scale=6.5, + num_inference_steps=25, + generator=generator, + callback_on_step_end=callback, +) + +out.images[0].save("official_callback.png") +``` + +
+
+ generated image of a sports car at the road +
without SDXLCFGCutoffCallback
+
+
+ generated image of a a sports car at the road with cfg callback +
with SDXLCFGCutoffCallback
+
+
+ ## Dynamic classifier-free guidance Dynamic classifier-free guidance (CFG) is a feature that allows you to disable CFG after a certain number of inference steps which can help you save compute with minimal cost to performance. The callback function for this should have the following arguments: -* `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`. -* `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`. -* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly. +- `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`. +- `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`. +- `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly. Your callback function should look something like this: diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py new file mode 100644 index 0000000000..38542407e3 --- /dev/null +++ b/src/diffusers/callbacks.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List + +from .configuration_utils import ConfigMixin, register_to_config +from .utils import CONFIG_NAME + + +class PipelineCallback(ConfigMixin): + """ + Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing + custom callbacks and ensures that all callbacks have a consistent interface. + + Please implement the following: + `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to + include + variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + `callback_fn`: This method defines the core functionality of your callback. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None): + super().__init__() + + if (cutoff_step_ratio is None and cutoff_step_index is None) or ( + cutoff_step_ratio is not None and cutoff_step_index is not None + ): + raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.") + + if cutoff_step_ratio is not None and ( + not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0) + ): + raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.") + + @property + def tensor_inputs(self) -> List[str]: + raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}") + + def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]: + raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}") + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + return self.callback_fn(pipeline, step_index, timestep, callback_kwargs) + + +class MultiPipelineCallbacks: + """ + This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and + provides a unified interface for calling all of them. + """ + + def __init__(self, callbacks: List[PipelineCallback]): + self.callbacks = callbacks + + @property + def tensor_inputs(self) -> List[str]: + return [input for callback in self.callbacks for input in callback.tensor_inputs] + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + """ + Calls all the callbacks in order with the given arguments and returns the final callback_kwargs. + """ + for callback in self.callbacks: + callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs) + + return callback_kwargs + + +class SDCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = ["prompt_embeds"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + return callback_kwargs + + +class SDXLCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + add_text_embeds = callback_kwargs[self.tensor_inputs[1]] + add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens + + add_time_ids = callback_kwargs[self.tensor_inputs[2]] + add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + callback_kwargs[self.tensor_inputs[1]] = add_text_embeds + callback_kwargs[self.tensor_inputs[2]] = add_time_ids + return callback_kwargs + + +class IPAdapterScaleCutoffCallback(PipelineCallback): + """ + Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by + `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`. + + Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step. + """ + + tensor_inputs = [] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + pipeline.set_ip_adapter_scale(0.0) + return callback_kwargs diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index e64dcdc554..cf979c352c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -22,6 +22,7 @@ import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel @@ -926,7 +927,9 @@ class StableDiffusionControlNetPipeline( control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1019,11 +1022,11 @@ class StableDiffusionControlNetPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1055,6 +1058,9 @@ class StableDiffusionControlNetPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 2e44efa78b..e6f1a06bdd 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -21,6 +21,7 @@ import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel @@ -917,7 +918,9 @@ class StableDiffusionControlNetImg2ImgPipeline( control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1004,11 +1007,11 @@ class StableDiffusionControlNetImg2ImgPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1040,6 +1043,9 @@ class StableDiffusionControlNetImg2ImgPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index cdc34819d5..d29e3ac8f9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -23,6 +23,7 @@ import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel @@ -1134,7 +1135,9 @@ class StableDiffusionControlNetInpaintPipeline( control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1239,11 +1242,11 @@ class StableDiffusionControlNetInpaintPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1275,6 +1278,9 @@ class StableDiffusionControlNetInpaintPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 3cfdefa9d4..125bb5a75d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -27,6 +27,7 @@ from transformers import ( CLIPVisionModelWithProjection, ) +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -1178,7 +1179,9 @@ class StableDiffusionXLControlNetInpaintPipeline( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1317,11 +1320,11 @@ class StableDiffusionXLControlNetInpaintPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1351,6 +1354,9 @@ class StableDiffusionXLControlNetInpaintPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 763188f347..95f55c3b86 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -30,6 +30,7 @@ from transformers import ( from diffusers.utils.import_utils import is_invisible_watermark_available +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -1031,7 +1032,9 @@ class StableDiffusionXLControlNetPipeline( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1169,11 +1172,11 @@ class StableDiffusionXLControlNetPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1203,6 +1206,9 @@ class StableDiffusionXLControlNetPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index dbd406d928..107a569b01 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -30,6 +30,7 @@ from transformers import ( from diffusers.utils.import_utils import is_invisible_watermark_available +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -1105,7 +1106,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1254,11 +1257,11 @@ class StableDiffusionXLControlNetImg2ImgPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1288,6 +1291,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 50cd24e4fa..3675a99ba6 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -21,6 +21,7 @@ import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel @@ -648,7 +649,9 @@ class StableDiffusionControlNetXSPipeline( control_guidance_start: float = 0.0, control_guidance_end: float = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): r""" @@ -715,11 +718,11 @@ class StableDiffusionControlNetXSPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -734,6 +737,9 @@ class StableDiffusionControlNetXSPipeline( "not-safe-for-work" (nsfw) content. """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index e572412f6e..1cd3f90e4e 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -28,6 +28,7 @@ from transformers import ( from diffusers.utils.import_utils import is_invisible_watermark_available +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel @@ -739,7 +740,9 @@ class StableDiffusionXLControlNetXSPipeline( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): r""" @@ -851,11 +854,11 @@ class StableDiffusionXLControlNetXSPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -869,6 +872,9 @@ class StableDiffusionXLControlNetXSPipeline( returned, otherwise a `tuple` is returned containing the output images. """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index e8ab72421d..087ce151f0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import inspect from typing import Any, Callable, Dict, List, Optional, Union @@ -19,6 +18,7 @@ import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -775,7 +775,9 @@ class StableDiffusionPipeline( cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -845,11 +847,11 @@ class StableDiffusionPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -881,6 +883,9 @@ class StableDiffusionPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f2a5de8154..d806f230a0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -21,6 +21,7 @@ import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -862,7 +863,9 @@ class StableDiffusionImg2ImgPipeline( return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: int = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -932,11 +935,11 @@ class StableDiffusionImg2ImgPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -967,6 +970,9 @@ class StableDiffusionImg2ImgPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 71dec964fd..37a3b32994 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -21,6 +21,7 @@ import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -1014,7 +1015,9 @@ class StableDiffusionInpaintPipeline( return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: int = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1107,11 +1110,11 @@ class StableDiffusionInpaintPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1171,6 +1174,9 @@ class StableDiffusionInpaintPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index b2b2b14009..1443c8b0af 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -20,6 +20,7 @@ import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel @@ -175,7 +176,9 @@ class StableDiffusionInstructPix2PixPipeline( ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -227,11 +230,11 @@ class StableDiffusionInstructPix2PixPipeline( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -290,6 +293,9 @@ class StableDiffusionInstructPix2PixPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Check inputs self.check_inputs( prompt, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 52d0b07fb3..2568150fa5 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -24,6 +24,7 @@ from transformers import ( CLIPVisionModelWithProjection, ) +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -861,7 +862,9 @@ class StableDiffusionXLPipeline( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -992,11 +995,11 @@ class StableDiffusionXLPipeline( as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1026,6 +1029,9 @@ class StableDiffusionXLPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index b8698a0083..838489dca7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -25,6 +25,7 @@ from transformers import ( CLIPVisionModelWithProjection, ) +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -1008,7 +1009,9 @@ class StableDiffusionXLImg2ImgPipeline( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1157,11 +1160,11 @@ class StableDiffusionXLImg2ImgPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1191,6 +1194,9 @@ class StableDiffusionXLImg2ImgPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 38f5cec931..631e309993 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -26,6 +26,7 @@ from transformers import ( CLIPVisionModelWithProjection, ) +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -1243,7 +1244,9 @@ class StableDiffusionXLInpaintPipeline( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1411,11 +1414,11 @@ class StableDiffusionXLInpaintPipeline( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1445,6 +1448,9 @@ class StableDiffusionXLInpaintPipeline( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor