diff --git a/examples/server/requirements.txt b/examples/server/requirements.txt index 5cc6e2303b..b91a8861a0 100644 --- a/examples/server/requirements.txt +++ b/examples/server/requirements.txt @@ -1,10 +1,10 @@ # This file was autogenerated by uv via the following command: # uv pip compile requirements.in -o requirements.txt -aiohappyeyeballs==2.4.3 +aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.10.10 +aiohttp==3.12.14 # via -r requirements.in -aiosignal==1.3.1 +aiosignal==1.4.0 # via aiohttp annotated-types==0.7.0 # via pydantic @@ -29,7 +29,6 @@ filelock==3.16.1 # huggingface-hub # torch # transformers - # triton frozenlist==1.5.0 # via # aiohttp @@ -111,7 +110,9 @@ prometheus-client==0.21.0 prometheus-fastapi-instrumentator==7.0.0 # via -r requirements.in propcache==0.2.0 - # via yarl + # via + # aiohttp + # yarl py-consul==1.5.3 # via -r requirements.in pydantic==2.9.2 @@ -155,7 +156,9 @@ triton==3.3.0 # via torch typing-extensions==4.12.2 # via + # aiosignal # anyio + # exceptiongroup # fastapi # huggingface-hub # multidict @@ -168,5 +171,5 @@ urllib3==2.5.0 # via requests uvicorn==0.32.0 # via -r requirements.in -yarl==1.16.0 +yarl==1.18.3 # via aiohttp diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 048ddcae32..91efdb0396 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -763,4 +763,7 @@ class LegacyConfigMixin(ConfigMixin): # resolve remapping remapped_class = _fetch_remapped_cls_from_config(config, cls) - return remapped_class.from_config(config, return_unused_kwargs, **kwargs) + if remapped_class is cls: + return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs) + else: + return remapped_class.from_config(config, return_unused_kwargs, **kwargs) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index b9b86cf480..76fefc1260 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -24,7 +24,7 @@ from typing_extensions import Self from .. import __version__ from ..quantizers import DiffusersAutoQuantizer from ..utils import deprecate, is_accelerate_available, logging -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache from .single_file_utils import ( SingleFileComponentError, convert_animatediff_checkpoint_to_diffusers, @@ -431,10 +431,7 @@ class FromOriginalModelMixin: keep_in_fp32_modules=keep_in_fp32_modules, unexpected_keys=unexpected_keys, ) - # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is - # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() - device_synchronize() else: _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 5fafcb02be..a804ea80a9 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -46,7 +46,7 @@ from ..utils import ( ) from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT from ..utils.hub_utils import _get_model_file -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache if is_transformers_available(): @@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm( if is_accelerate_available(): load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is - # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() - device_synchronize() else: model.load_state_dict(diffusers_format_checkpoint, strict=False) @@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint( if is_accelerate_available(): load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is - # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() - device_synchronize() else: model.load_state_dict(diffusers_format_checkpoint) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index af03d09029..0de8095948 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -19,7 +19,7 @@ from ..models.embeddings import ( ) from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import is_accelerate_available, is_torch_version, logging -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache if is_accelerate_available(): @@ -82,7 +82,6 @@ class FluxTransformer2DLoadersMixin: device_map = {"": self.device} load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) empty_device_cache() - device_synchronize() return image_projection @@ -158,7 +157,6 @@ class FluxTransformer2DLoadersMixin: key_id += 1 empty_device_cache() - device_synchronize() return attn_procs diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index 4421f46dfc..1bc3a9c7a8 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 from ..models.embeddings import IPAdapterTimeImageProjection from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import is_accelerate_available, is_torch_version, logging -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache logger = logging.get_logger(__name__) @@ -82,7 +82,6 @@ class SD3Transformer2DLoadersMixin: ) empty_device_cache() - device_synchronize() return attn_procs @@ -152,7 +151,6 @@ class SD3Transformer2DLoadersMixin: device_map = {"": self.device} load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype) empty_device_cache() - device_synchronize() return image_proj diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 89c6449ff5..1d698e5a8b 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -43,7 +43,7 @@ from ..utils import ( is_torch_version, logging, ) -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache from .lora_base import _func_optionally_disable_offloading from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .utils import AttnProcsLayers @@ -755,7 +755,6 @@ class UNet2DConditionLoadersMixin: device_map = {"": self.device} load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) empty_device_cache() - device_synchronize() return image_projection @@ -854,7 +853,6 @@ class UNet2DConditionLoadersMixin: key_id += 2 empty_device_cache() - device_synchronize() return attn_procs diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d7b2136b4a..3707f70b9d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -62,7 +62,7 @@ from ..utils.hub_utils import ( load_or_create_model_card, populate_model_card, ) -from ..utils.torch_utils import device_synchronize, empty_device_cache +from ..utils.torch_utils import empty_device_cache from .model_loading_utils import ( _caching_allocator_warmup, _determine_device_map, @@ -1540,10 +1540,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) - # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is - # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() - device_synchronize() if offload_index is not None and len(offload_index) > 0: save_offload_index(offload_index, offload_folder) @@ -1880,4 +1877,9 @@ class LegacyModelMixin(ModelMixin): # resolve remapping remapped_class = _fetch_remapped_cls_from_config(config, cls) - return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy) + if remapped_class is cls: + return super(LegacyModelMixin, remapped_class).from_pretrained( + pretrained_model_name_or_path, **kwargs_copy + ) + else: + return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index 82ef4b6391..65e2fe6617 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch -import torch.nn.functional as F from transformers import ( CLIPImageProcessor, CLIPTextModel, @@ -38,7 +37,13 @@ from ...loaders import ( StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel +from ...models import ( + AutoencoderKL, + ControlNetUnionModel, + ImageProjection, + MultiControlNetUnionModel, + UNet2DConditionModel, +) from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, @@ -262,7 +267,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: ControlNetUnionModel, + controlnet: Union[ + ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel + ], scheduler: KarrasDiffusionSchedulers, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, @@ -272,8 +279,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( ): super().__init__() - if not isinstance(controlnet, ControlNetUnionModel): - raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.") + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetUnionModel(controlnet) self.register_modules( vae=vae, @@ -649,6 +656,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + control_mode=None, callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: @@ -722,28 +730,44 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetUnionModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, ControlNetUnionModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetUnionModel) - ): - self.check_image(image, prompt, prompt_embeds) - else: - assert False + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, ControlNetUnionModel): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + elif isinstance(controlnet, MultiControlNetUnionModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + elif not all(isinstance(i, list) for i in image): + raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for images_ in image: + for image_ in images_: + self.check_image(image_, prompt, prompt_embeds) if not isinstance(control_guidance_start, (tuple, list)): control_guidance_start = [control_guidance_start] + if isinstance(controlnet, MultiControlNetUnionModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + if not isinstance(control_guidance_end, (tuple, list)): control_guidance_end = [control_guidance_end] @@ -762,6 +786,15 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + # Check `control_mode` + if isinstance(controlnet, ControlNetUnionModel): + if max(control_mode) >= controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.") + elif isinstance(controlnet, MultiControlNetUnionModel): + for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets): + if max(_control_mode) >= _controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.") + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." @@ -1049,7 +1082,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, - control_image: PipelineImageInput = None, + control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.8, @@ -1074,7 +1107,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, - control_mode: Optional[Union[int, List[int]]] = None, + control_mode: Optional[Union[int, List[int], List[List[int]]]] = None, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, @@ -1104,13 +1137,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The initial image will be used as the starting point for the image generation process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. - control_image (`PipelineImageInput`): - The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If - the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also - be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in - init, images must be passed as a list such that each element of the list can be correctly batched for - input to a single controlnet. + control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. height (`int`, *optional*, defaults to the size of control_image): The height in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) @@ -1184,16 +1217,21 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. If multiple ControlNets are specified in init, you can set the - corresponding scale as a list. + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the controlnet starts applying. + The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the controlnet stops applying. + The percentage of total steps at which the ControlNet stops applying. + control_mode (`int` or `List[int]` or `List[List[int]], *optional*): + The control condition types for the ControlNet. See the ControlNet's model card forinformation on the + available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list + where each ControlNet should have its corresponding control mode list. Should reflect the order of + conditions in control_image original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as @@ -1273,12 +1311,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - if not isinstance(control_image, list): control_image = [control_image] else: @@ -1287,37 +1319,56 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( if not isinstance(control_mode, list): control_mode = [control_mode] - if len(control_image) != len(control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") + if isinstance(controlnet, MultiControlNetUnionModel): + control_image = [[item] for item in control_image] + control_mode = [[item] for item in control_mode] - num_control_type = controlnet.config.num_control_type - - # 1. Check inputs - control_type = [0 for _ in range(num_control_type)] - for _image, control_idx in zip(control_image, control_mode): - control_type[control_idx] = 1 - self.check_inputs( - prompt, - prompt_2, - _image, - strength, - num_inference_steps, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], ) - control_type = torch.Tensor(control_type) + if isinstance(controlnet_conditioning_scale, float): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + control_mode, + callback_on_step_end_tensor_inputs, + ) + + if isinstance(controlnet, ControlNetUnionModel): + control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1) + for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets) + ] self._guidance_scale = guidance_scale self._clip_skip = clip_skip @@ -1334,7 +1385,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( device = self._execution_device - global_pool_conditions = controlnet.config.global_pool_conditions + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetUnionModel) + else controlnet.nets[0].config.global_pool_conditions + ) guess_mode = guess_mode or global_pool_conditions # 3.1. Encode input prompt @@ -1372,22 +1427,55 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( self.do_classifier_free_guidance, ) - # 4. Prepare image and controlnet_conditioning_image + # 4.1 Prepare image image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - for idx, _ in enumerate(control_image): - control_image[idx] = self.prepare_control_image( - image=control_image[idx], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = control_image[idx].shape[-2:] + # 4.2 Prepare control images + if isinstance(controlnet, ControlNetUnionModel): + control_images = [] + + for image_ in control_image: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + + elif isinstance(controlnet, MultiControlNetUnionModel): + control_images = [] + + for control_image_ in control_image: + images = [] + + for image_ in control_image_: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + control_images.append(images) + + control_image = control_images + height, width = control_image[0][0].shape[-2:] # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1414,10 +1502,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): - controlnet_keep.append( - 1.0 - - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end) - ) + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps) # 7.2 Prepare added time ids & embeddings original_size = original_size or (height, width) @@ -1460,12 +1549,25 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) - control_type = ( - control_type.reshape(1, -1) - .to(device, dtype=prompt_embeds.dtype) - .repeat(batch_size * num_images_per_prompt * 2, 1) + + control_type_repeat_factor = ( + batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1) ) + if isinstance(controlnet, ControlNetUnionModel): + control_type = ( + control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + ) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + _control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + for _control_type in control_type + ] + # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index bd8609a11a..06c2076816 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -383,7 +383,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): # set timesteps self.scheduler.set_timesteps(num_inference_steps) - latents = latents * np.float64(self.scheduler.init_noise_sigma) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 6a952a7ae6..141d849ec3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps) # scale the initial noise by the standard deviation required by the scheduler - latents = latents * np.float64(self.scheduler.init_noise_sigma) + latents = latents * self.scheduler.init_noise_sigma # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index 3f10764dc7..882fa98b07 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline): timesteps = self.scheduler.timesteps # Scale the initial noise by the standard deviation required by the scheduler - latents = latents * np.float64(self.scheduler.init_noise_sigma) + latents = latents * self.scheduler.init_noise_sigma # 5. Add noise to image noise_level = np.array([noise_level]).astype(np.int64) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 0df0e028ff..a848ec615e 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -155,7 +155,7 @@ class FluxPipelineFastTests( # Outputs should be different here # For some reasons, they don't show large differences - assert max_diff > 1e-6 + self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.") def test_fused_qkv_projections(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -187,14 +187,17 @@ class FluxPipelineFastTests( image = pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( - "Fusion of QKV projections shouldn't affect the outputs." + self.assertTrue( + np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), + ("Fusion of QKV projections shouldn't affect the outputs."), ) - assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( - "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + self.assertTrue( + np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), + ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."), ) - assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( - "Original outputs should match when fused QKV projections are disabled." + self.assertTrue( + np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), + ("Original outputs should match when fused QKV projections are disabled."), ) def test_flux_image_output_shape(self): @@ -209,7 +212,11 @@ class FluxPipelineFastTests( inputs.update({"height": height, "width": width}) image = pipe(**inputs).images[0] output_height, output_width, _ = image.shape - assert (output_height, output_width) == (expected_height, expected_width) + self.assertEqual( + (output_height, output_width), + (expected_height, expected_width), + f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}", + ) def test_flux_true_cfg(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) @@ -220,7 +227,9 @@ class FluxPipelineFastTests( inputs["negative_prompt"] = "bad quality" inputs["true_cfg_scale"] = 2.0 true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] - assert not np.allclose(no_true_cfg_out, true_cfg_out) + self.assertFalse( + np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set." + ) @nightly @@ -269,45 +278,17 @@ class FluxPipelineSlowTests(unittest.TestCase): image = pipe(**inputs).images[0] image_slice = image[0, :10, :10] + # fmt: off expected_slice = np.array( - [ - 0.3242, - 0.3203, - 0.3164, - 0.3164, - 0.3125, - 0.3125, - 0.3281, - 0.3242, - 0.3203, - 0.3301, - 0.3262, - 0.3242, - 0.3281, - 0.3242, - 0.3203, - 0.3262, - 0.3262, - 0.3164, - 0.3262, - 0.3281, - 0.3184, - 0.3281, - 0.3281, - 0.3203, - 0.3281, - 0.3281, - 0.3164, - 0.3320, - 0.3320, - 0.3203, - ], + [0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203], dtype=np.float32, ) + # fmt: on max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) - - assert max_diff < 1e-4 + self.assertLess( + max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}" + ) @slow @@ -377,42 +358,14 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase): image = pipe(**inputs).images[0] image_slice = image[0, :10, :10] + # fmt: off expected_slice = np.array( - [ - 0.1855, - 0.1680, - 0.1406, - 0.1953, - 0.1699, - 0.1465, - 0.2012, - 0.1738, - 0.1484, - 0.2051, - 0.1797, - 0.1523, - 0.2012, - 0.1719, - 0.1445, - 0.2070, - 0.1777, - 0.1465, - 0.2090, - 0.1836, - 0.1484, - 0.2129, - 0.1875, - 0.1523, - 0.2090, - 0.1816, - 0.1484, - 0.2110, - 0.1836, - 0.1543, - ], + [0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543], dtype=np.float32, ) + # fmt: on max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) - - assert max_diff < 1e-4, f"{image_slice} != {expected_slice}" + self.assertLess( + max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}" + )