diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml index fd0c76c2ff..bb74daad21 100644 --- a/.github/workflows/pr_tests_gpu.yml +++ b/.github/workflows/pr_tests_gpu.yml @@ -13,6 +13,7 @@ on: - "src/diffusers/loaders/peft.py" - "tests/pipelines/test_pipelines_common.py" - "tests/models/test_modeling_common.py" + - "examples/**/*.py" workflow_dispatch: concurrency: diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 965bb554a8..4fa0c906b5 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -58,6 +58,7 @@ from diffusers.training_utils import ( compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory, + offload_models, ) from diffusers.utils import ( check_min_version, @@ -1364,43 +1365,34 @@ def main(args): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not train_dataset.custom_instance_prompts: - if args.offload: - text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) - ( - instance_prompt_hidden_states_t5, - instance_prompt_hidden_states_llama3, - instance_pooled_prompt_embeds, - _, - _, - _, - ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline) - if args.offload: - text_encoding_pipeline = text_encoding_pipeline.to("cpu") + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + ( + instance_prompt_hidden_states_t5, + instance_prompt_hidden_states_llama3, + instance_pooled_prompt_embeds, + _, + _, + _, + ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline) # Handle class prompt for prior-preservation. 
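The same move-the-pipeline, encode, move-it-back bracketing is replaced for the class and validation prompts below. A minimal sketch of the refactor, assuming the `offload_models` context manager this diff adds to `src/diffusers/training_utils.py` (names mirror the training script):

```python
from diffusers.training_utils import offload_models

# Previously every call site did the device round trip by hand:
#   text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
#   embeds = compute_text_embeddings(prompt, text_encoding_pipeline)
#   text_encoding_pipeline = text_encoding_pipeline.to("cpu")
#
# The context manager scopes the placement and is a no-op when offload=False,
# so the call site reads the same whether offloading is enabled or not.
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
    embeds = compute_text_embeddings(prompt, text_encoding_pipeline)
```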
if args.with_prior_preservation: - if args.offload: - text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) - (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = ( - compute_text_embeddings(args.class_prompt, text_encoding_pipeline) - ) - if args.offload: - text_encoding_pipeline = text_encoding_pipeline.to("cpu") + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = ( + compute_text_embeddings(args.class_prompt, text_encoding_pipeline) + ) validation_embeddings = {} if args.validation_prompt is not None: - if args.offload: - text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) - ( - validation_embeddings["prompt_embeds_t5"], - validation_embeddings["prompt_embeds_llama3"], - validation_embeddings["pooled_prompt_embeds"], - validation_embeddings["negative_prompt_embeds_t5"], - validation_embeddings["negative_prompt_embeds_llama3"], - validation_embeddings["negative_pooled_prompt_embeds"], - ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline) - if args.offload: - text_encoding_pipeline = text_encoding_pipeline.to("cpu") + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + ( + validation_embeddings["prompt_embeds_t5"], + validation_embeddings["prompt_embeds_llama3"], + validation_embeddings["pooled_prompt_embeds"], + validation_embeddings["negative_prompt_embeds_t5"], + validation_embeddings["negative_prompt_embeds_llama3"], + validation_embeddings["negative_pooled_prompt_embeds"], + ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline) # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't @@ -1581,12 +1573,10 @@ def main(args): if args.cache_latents: model_input = latents_cache[step].sample() else: - if args.offload: - vae = vae.to(accelerator.device) - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) model_input = vae.encode(pixel_values).latent_dist.sample() - if args.offload: - vae = vae.to("cpu") + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 048ddcae32..91efdb0396 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -763,4 +763,7 @@ class LegacyConfigMixin(ConfigMixin): # resolve remapping remapped_class = _fetch_remapped_cls_from_config(config, cls) - return remapped_class.from_config(config, return_unused_kwargs, **kwargs) + if remapped_class is cls: + return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs) + else: + return remapped_class.from_config(config, return_unused_kwargs, **kwargs) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index b9b86cf480..76fefc1260 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -24,7 +24,7 @@ from typing_extensions import Self from .. 
import __version__ from ..quantizers import DiffusersAutoQuantizer from ..utils import deprecate, is_accelerate_available, logging -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache from .single_file_utils import ( SingleFileComponentError, convert_animatediff_checkpoint_to_diffusers, @@ -431,10 +431,7 @@ class FromOriginalModelMixin: keep_in_fp32_modules=keep_in_fp32_modules, unexpected_keys=unexpected_keys, ) - # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is - # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() - device_synchronize() else: _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 5fafcb02be..a804ea80a9 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -46,7 +46,7 @@ from ..utils import ( ) from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT from ..utils.hub_utils import _get_model_file -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache if is_transformers_available(): @@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm( if is_accelerate_available(): load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is - # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() - device_synchronize() else: model.load_state_dict(diffusers_format_checkpoint, strict=False) @@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint( if is_accelerate_available(): load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is - # required because we move tensors with non_blocking=True, which is slightly faster for model loading. 
empty_device_cache() - device_synchronize() else: model.load_state_dict(diffusers_format_checkpoint) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index 0873e8edd0..ced81960fa 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -19,7 +19,7 @@ from ..models.embeddings import ( ) from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import is_accelerate_available, is_torch_version, logging -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache if is_accelerate_available(): @@ -82,7 +82,6 @@ class FluxTransformer2DLoadersMixin: device_map = {"": self.device} load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) empty_device_cache() - device_synchronize() return image_projection @@ -156,7 +155,6 @@ class FluxTransformer2DLoadersMixin: key_id += 1 empty_device_cache() - device_synchronize() return attn_procs diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index 4421f46dfc..1bc3a9c7a8 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 from ..models.embeddings import IPAdapterTimeImageProjection from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import is_accelerate_available, is_torch_version, logging -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache logger = logging.get_logger(__name__) @@ -82,7 +82,6 @@ class SD3Transformer2DLoadersMixin: ) empty_device_cache() - device_synchronize() return attn_procs @@ -152,7 +151,6 @@ class SD3Transformer2DLoadersMixin: device_map = {"": self.device} load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype) empty_device_cache() - device_synchronize() return image_proj diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 89c6449ff5..1d698e5a8b 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -43,7 +43,7 @@ from ..utils import ( is_torch_version, logging, ) -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache from .lora_base import _func_optionally_disable_offloading from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .utils import AttnProcsLayers @@ -755,7 +755,6 @@ class UNet2DConditionLoadersMixin: device_map = {"": self.device} load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) empty_device_cache() - device_synchronize() return image_projection @@ -854,7 +853,6 @@ class UNet2DConditionLoadersMixin: key_id += 2 empty_device_cache() - device_synchronize() return attn_procs diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4918fae91d..01ebb1a910 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -62,7 +62,7 @@ from ..utils.hub_utils import ( load_or_create_model_card, populate_model_card, ) -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import 
empty_device_cache from .model_loading_utils import ( _caching_allocator_warmup, _determine_device_map, @@ -1590,10 +1590,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) - # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is - # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() - device_synchronize() if offload_index is not None and len(offload_index) > 0: save_offload_index(offload_index, offload_folder) @@ -1930,4 +1927,9 @@ class LegacyModelMixin(ModelMixin): # resolve remapping remapped_class = _fetch_remapped_cls_from_config(config, cls) - return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy) + if remapped_class is cls: + return super(LegacyModelMixin, remapped_class).from_pretrained( + pretrained_model_name_or_path, **kwargs_copy + ) + else: + return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index 82ef4b6391..65e2fe6617 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch -import torch.nn.functional as F from transformers import ( CLIPImageProcessor, CLIPTextModel, @@ -38,7 +37,13 @@ from ...loaders import ( StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel +from ...models import ( + AutoencoderKL, + ControlNetUnionModel, + ImageProjection, + MultiControlNetUnionModel, + UNet2DConditionModel, +) from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, @@ -262,7 +267,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: ControlNetUnionModel, + controlnet: Union[ + ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel + ], scheduler: KarrasDiffusionSchedulers, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, @@ -272,8 +279,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( ): super().__init__() - if not isinstance(controlnet, ControlNetUnionModel): - raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.") + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetUnionModel(controlnet) self.register_modules( vae=vae, @@ -649,6 +656,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + control_mode=None, callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: @@ -722,28 +730,44 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. 
Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetUnionModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, ControlNetUnionModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetUnionModel) - ): - self.check_image(image, prompt, prompt_embeds) - else: - assert False + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, ControlNetUnionModel): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + elif isinstance(controlnet, MultiControlNetUnionModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + elif not all(isinstance(i, list) for i in image): + raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for images_ in image: + for image_ in images_: + self.check_image(image_, prompt, prompt_embeds) if not isinstance(control_guidance_start, (tuple, list)): control_guidance_start = [control_guidance_start] + if isinstance(controlnet, MultiControlNetUnionModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + if not isinstance(control_guidance_end, (tuple, list)): control_guidance_end = [control_guidance_end] @@ -762,6 +786,15 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + # Check `control_mode` + if isinstance(controlnet, ControlNetUnionModel): + if max(control_mode) >= controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.") + elif isinstance(controlnet, MultiControlNetUnionModel): + for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets): + if max(_control_mode) >= _controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.") + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
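The `check_inputs` changes above pin down the nested layout the multi-ControlNet-Union path expects: `image` becomes a list of per-ControlNet conditioning lists, and every entry of each `control_mode` list must stay below the corresponding `num_control_type`. A minimal sketch of that bookkeeping, assuming two hypothetical ControlNet-Union checkpoints with six control types each (the blank images stand in for real depth and pose maps):

```python
import torch
from PIL import Image

# Placeholder conditionings; in practice these would be depth / pose maps.
depth_image = Image.new("RGB", (1024, 1024))
pose_image = Image.new("RGB", (1024, 1024))

# Flat per-ControlNet inputs as passed to __call__.
control_image = [depth_image, pose_image]  # one conditioning per ControlNetUnionModel
control_mode = [1, 5]                      # one control type per ControlNetUnionModel

# __call__ nests flat inputs so each ControlNet gets its own list of
# conditionings and modes (mirrors the hunk further below).
control_image = [[item] for item in control_image]
control_mode = [[item] for item in control_mode]

# Per-ControlNet one-hot control_type vectors, built with scatter_ as in the
# pipeline; 6 is a hypothetical num_control_type.
num_control_type = 6
control_type = [
    torch.zeros(num_control_type).scatter_(0, torch.tensor(modes), 1)
    for modes in control_mode
]
# -> index 1 is set for the first ControlNet, index 5 for the second
```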
@@ -1049,7 +1082,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, - control_image: PipelineImageInput = None, + control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.8, @@ -1074,7 +1107,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, - control_mode: Optional[Union[int, List[int]]] = None, + control_mode: Optional[Union[int, List[int], List[List[int]]]] = None, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, @@ -1104,13 +1137,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The initial image will be used as the starting point for the image generation process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. - control_image (`PipelineImageInput`): - The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If - the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also - be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in - init, images must be passed as a list such that each element of the list can be correctly batched for - input to a single controlnet. + control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. height (`int`, *optional*, defaults to the size of control_image): The height in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) @@ -1184,16 +1217,21 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. If multiple ControlNets are specified in init, you can set the - corresponding scale as a list. + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. 
guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the controlnet starts applying. + The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the controlnet stops applying. + The percentage of total steps at which the ControlNet stops applying. + control_mode (`int` or `List[int]` or `List[List[int]]`, *optional*): + The control condition types for the ControlNet. See the ControlNet's model card for information on the + available control modes. If multiple ControlNets are specified in `init`, `control_mode` should be a list + where each ControlNet has its corresponding control mode list. The order should reflect the order of + conditions in `control_image`. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as @@ -1273,12 +1311,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - if not isinstance(control_image, list): control_image = [control_image] else: @@ -1287,37 +1319,56 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( if not isinstance(control_mode, list): control_mode = [control_mode] - if len(control_image) != len(control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") + if isinstance(controlnet, MultiControlNetUnionModel): + control_image = [[item] for item in control_image] + control_mode = [[item] for item in control_mode] - num_control_type = controlnet.config.num_control_type - - # 1. 
Check inputs - control_type = [0 for _ in range(num_control_type)] - for _image, control_idx in zip(control_image, control_mode): - control_type[control_idx] = 1 - self.check_inputs( - prompt, - prompt_2, - _image, - strength, - num_inference_steps, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], ) - control_type = torch.Tensor(control_type) + if isinstance(controlnet_conditioning_scale, float): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + control_mode, + callback_on_step_end_tensor_inputs, + ) + + if isinstance(controlnet, ControlNetUnionModel): + control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1) + for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets) + ] self._guidance_scale = guidance_scale self._clip_skip = clip_skip @@ -1334,7 +1385,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( device = self._execution_device - global_pool_conditions = controlnet.config.global_pool_conditions + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetUnionModel) + else controlnet.nets[0].config.global_pool_conditions + ) guess_mode = guess_mode or global_pool_conditions # 3.1. Encode input prompt @@ -1372,22 +1427,55 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( self.do_classifier_free_guidance, ) - # 4. 
Prepare image and controlnet_conditioning_image + # 4.1 Prepare image image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - for idx, _ in enumerate(control_image): - control_image[idx] = self.prepare_control_image( - image=control_image[idx], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = control_image[idx].shape[-2:] + # 4.2 Prepare control images + if isinstance(controlnet, ControlNetUnionModel): + control_images = [] + + for image_ in control_image: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + + elif isinstance(controlnet, MultiControlNetUnionModel): + control_images = [] + + for control_image_ in control_image: + images = [] + + for image_ in control_image_: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + control_images.append(images) + + control_image = control_images + height, width = control_image[0][0].shape[-2:] # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1414,10 +1502,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): - controlnet_keep.append( - 1.0 - - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end) - ) + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps) # 7.2 Prepare added time ids & embeddings original_size = original_size or (height, width) @@ -1460,12 +1549,25 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) - control_type = ( - control_type.reshape(1, -1) - .to(device, dtype=prompt_embeds.dtype) - .repeat(batch_size * num_images_per_prompt * 2, 1) + + control_type_repeat_factor = ( + batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1) ) + if isinstance(controlnet, ControlNetUnionModel): + control_type = ( + control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + ) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + _control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + for _control_type in control_type + ] + # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 073d94750a..6e6e5a4c7f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -840,6 +840,8 @@ class FluxPipeline( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index bd8609a11a..06c2076816 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -383,7 +383,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): # set timesteps self.scheduler.set_timesteps(num_inference_steps) - latents = latents * np.float64(self.scheduler.init_noise_sigma) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 6a952a7ae6..141d849ec3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps) # scale the initial noise by the standard deviation required by the scheduler - latents = latents * np.float64(self.scheduler.init_noise_sigma) + latents = latents * self.scheduler.init_noise_sigma # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index 3f10764dc7..882fa98b07 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline): timesteps = self.scheduler.timesteps # Scale the initial noise by the standard deviation required by the scheduler - latents = latents * np.float64(self.scheduler.init_noise_sigma) + latents = latents * self.scheduler.init_noise_sigma # 5. 
Add noise to image noise_level = np.array([noise_level]).astype(np.int64) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 748a7e39c0..7d8685ba10 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -153,6 +153,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): flow_shift: Optional[float] = 1.0, timestep_spacing: str = "linspace", steps_offset: int = 0, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -232,7 +234,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -242,6 +246,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" + self.config.flow_shift = np.exp(mu) # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 1a648af5a0..d07ff8b200 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -230,6 +230,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): timestep_spacing: str = "linspace", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -330,6 +332,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): self, num_inference_steps: int = None, device: Union[str, torch.device] = None, + mu: Optional[float] = None, timesteps: Optional[List[int]] = None, ): """ @@ -345,6 +348,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`, and `timestep_spacing` attribute will be ignored. 
""" + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" + self.config.flow_shift = np.exp(mu) if num_inference_steps is None and timesteps is None: raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") if num_inference_steps is not None and timesteps is not None: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 9e3e830039..8663210a62 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -169,6 +169,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -301,6 +303,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): self, num_inference_steps: int = None, device: Union[str, torch.device] = None, + mu: Optional[float] = None, timesteps: Optional[List[int]] = None, ): """ @@ -316,6 +319,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`. """ + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" + self.config.flow_shift = np.exp(mu) if num_inference_steps is None and timesteps is None: raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") if num_inference_steps is not None and timesteps is not None: diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 8b1f699b10..1d2378fd4f 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -212,6 +212,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): steps_offset: int = 0, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" rescale_betas_zero_snr: bool = False, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -298,7 +300,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -309,6 +313,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://huggingface.co/papers/2305.08891 + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" + self.config.flow_shift = np.exp(mu) if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 755ff81883..d33b80dba0 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,12 +5,14 @@ import math import random import re import warnings +from contextlib import contextmanager from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch from .models import UNet2DConditionModel +from .pipelines import DiffusionPipeline from .schedulers import SchedulerMixin from .utils import ( convert_state_dict_to_diffusers, @@ -318,6 +320,39 @@ def free_memory(): torch.xpu.empty_cache() +@contextmanager +def offload_models( + *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True +): + """ + Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original + device on exit. + + Args: + device (`str` or `torch.device`): Device to move the `modules` to. + offload (`bool`): Flag to enable offloading. + """ + if offload: + is_model = not any(isinstance(m, DiffusionPipeline) for m in modules) + # record where each module was + if is_model: + original_devices = [next(m.parameters()).device for m in modules] + else: + assert len(modules) == 1 + original_devices = [modules[0].device] + # move to target device + for m in modules: + m.to(device) + + try: + yield + finally: + if offload: + # move back to original devices + for m, orig_dev in zip(modules, original_devices): + m.to(orig_dev) + + def parse_buckets_string(buckets_str): """Parses a string defining buckets into a list of (height, width) tuples.""" if not buckets_str:
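The `mu` hook added to the DEIS, DPM-Solver and UniPC schedulers above implements exponential dynamic shifting and pairs with the Flux pipeline change that hands sigma computation back to the scheduler when `use_flow_sigmas` is enabled. A minimal sketch of driving it with this diff applied (the scheduler choice and the values are illustrative):

```python
from diffusers import UniPCMultistepScheduler

# Flow-style sigmas plus the new dynamic-shifting config flags from this diff.
scheduler = UniPCMultistepScheduler(
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    use_dynamic_shifting=True,
    time_shift_type="exponential",
)

# In a pipeline, mu would come from calculate_shift(image_seq_len, ...);
# a fixed value is used here only to exercise the new code path.
mu = 1.15
scheduler.set_timesteps(num_inference_steps=30, device="cpu", mu=mu)
# Per the hunks above, set_timesteps now applies flow_shift = exp(mu)
# before computing the sigmas.
```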