mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into modular-test
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements.in -o requirements.txt
|
||||
aiohappyeyeballs==2.4.3
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.10.10
|
||||
aiohttp==3.12.14
|
||||
# via -r requirements.in
|
||||
aiosignal==1.3.1
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
@@ -29,7 +29,6 @@ filelock==3.16.1
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
# triton
|
||||
frozenlist==1.5.0
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -111,7 +110,9 @@ prometheus-client==0.21.0
|
||||
prometheus-fastapi-instrumentator==7.0.0
|
||||
# via -r requirements.in
|
||||
propcache==0.2.0
|
||||
# via yarl
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
py-consul==1.5.3
|
||||
# via -r requirements.in
|
||||
pydantic==2.9.2
|
||||
@@ -155,7 +156,9 @@ triton==3.3.0
|
||||
# via torch
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# multidict
|
||||
@@ -168,5 +171,5 @@ urllib3==2.5.0
|
||||
# via requests
|
||||
uvicorn==0.32.0
|
||||
# via -r requirements.in
|
||||
yarl==1.16.0
|
||||
yarl==1.18.3
|
||||
# via aiohttp
|
||||
|
||||
@@ -763,4 +763,7 @@ class LegacyConfigMixin(ConfigMixin):
|
||||
# resolve remapping
|
||||
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
||||
|
||||
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
|
||||
if remapped_class is cls:
|
||||
return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
|
||||
else:
|
||||
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
|
||||
|
||||
@@ -24,7 +24,7 @@ from typing_extensions import Self
|
||||
from .. import __version__
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
@@ -431,10 +431,7 @@ class FromOriginalModelMixin:
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
)
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from ..utils import (
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from ..utils.hub_utils import _get_model_file
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm(
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
@@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from ..models.embeddings import (
|
||||
)
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import is_accelerate_available, is_torch_version, logging
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -82,7 +82,6 @@ class FluxTransformer2DLoadersMixin:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return image_projection
|
||||
|
||||
@@ -158,7 +157,6 @@ class FluxTransformer2DLoadersMixin:
|
||||
key_id += 1
|
||||
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return attn_procs
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||
from ..models.embeddings import IPAdapterTimeImageProjection
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import is_accelerate_available, is_torch_version, logging
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -82,7 +82,6 @@ class SD3Transformer2DLoadersMixin:
|
||||
)
|
||||
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return attn_procs
|
||||
|
||||
@@ -152,7 +151,6 @@ class SD3Transformer2DLoadersMixin:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return image_proj
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from .lora_base import _func_optionally_disable_offloading
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||
from .utils import AttnProcsLayers
|
||||
@@ -755,7 +755,6 @@ class UNet2DConditionLoadersMixin:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return image_projection
|
||||
|
||||
@@ -854,7 +853,6 @@ class UNet2DConditionLoadersMixin:
|
||||
key_id += 2
|
||||
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return attn_procs
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ from ..utils.hub_utils import (
|
||||
load_or_create_model_card,
|
||||
populate_model_card,
|
||||
)
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from .model_loading_utils import (
|
||||
_caching_allocator_warmup,
|
||||
_determine_device_map,
|
||||
@@ -1540,10 +1540,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
|
||||
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
|
||||
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
if offload_index is not None and len(offload_index) > 0:
|
||||
save_offload_index(offload_index, offload_folder)
|
||||
@@ -1880,4 +1877,9 @@ class LegacyModelMixin(ModelMixin):
|
||||
# resolve remapping
|
||||
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
||||
|
||||
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
|
||||
if remapped_class is cls:
|
||||
return super(LegacyModelMixin, remapped_class).from_pretrained(
|
||||
pretrained_model_name_or_path, **kwargs_copy
|
||||
)
|
||||
else:
|
||||
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
|
||||
|
||||
@@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
@@ -38,7 +37,13 @@ from ...loaders import (
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
|
||||
from ...models import (
|
||||
AutoencoderKL,
|
||||
ControlNetUnionModel,
|
||||
ImageProjection,
|
||||
MultiControlNetUnionModel,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
@@ -262,7 +267,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: ControlNetUnionModel,
|
||||
controlnet: Union[
|
||||
ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
|
||||
],
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
requires_aesthetics_score: bool = False,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
@@ -272,8 +279,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not isinstance(controlnet, ControlNetUnionModel):
|
||||
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = MultiControlNetUnionModel(controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
@@ -649,6 +656,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
control_mode=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
@@ -722,28 +730,44 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
# `prompt` needs more sophisticated handling when there are multiple
|
||||
# conditionings.
|
||||
if isinstance(self.controlnet, MultiControlNetUnionModel):
|
||||
if isinstance(prompt, list):
|
||||
logger.warning(
|
||||
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
|
||||
" prompts. The conditionings will be fixed across the prompts."
|
||||
)
|
||||
|
||||
# Check `image`
|
||||
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
||||
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
||||
)
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
||||
):
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
elif (
|
||||
isinstance(self.controlnet, ControlNetUnionModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
|
||||
):
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
else:
|
||||
assert False
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
for image_ in image:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if not isinstance(image, list):
|
||||
raise TypeError("For multiple controlnets: `image` must be type `list`")
|
||||
elif not all(isinstance(i, list) for i in image):
|
||||
raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
|
||||
elif len(image) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
|
||||
)
|
||||
|
||||
for images_ in image:
|
||||
for image_ in images_:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
|
||||
if not isinstance(control_guidance_start, (tuple, list)):
|
||||
control_guidance_start = [control_guidance_start]
|
||||
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if len(control_guidance_start) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
|
||||
)
|
||||
|
||||
if not isinstance(control_guidance_end, (tuple, list)):
|
||||
control_guidance_end = [control_guidance_end]
|
||||
|
||||
@@ -762,6 +786,15 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
if end > 1.0:
|
||||
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
||||
|
||||
# Check `control_mode`
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
if max(control_mode) >= controlnet.config.num_control_type:
|
||||
raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
|
||||
if max(_control_mode) >= _controlnet.config.num_control_type:
|
||||
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
||||
@@ -1049,7 +1082,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
control_image: PipelineImageInput = None,
|
||||
control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 0.8,
|
||||
@@ -1074,7 +1107,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
control_mode: Optional[Union[int, List[int]]] = None,
|
||||
control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
|
||||
original_size: Tuple[int, int] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Tuple[int, int] = None,
|
||||
@@ -1104,13 +1137,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The initial image will be used as the starting point for the image generation process. Can also accept
|
||||
image latents as `image`, if passing latents directly, it will not be encoded again.
|
||||
control_image (`PipelineImageInput`):
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
|
||||
be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
||||
and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
|
||||
init, images must be passed as a list such that each element of the list can be correctly batched for
|
||||
input to a single controlnet.
|
||||
control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
|
||||
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
||||
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
|
||||
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
|
||||
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
|
||||
images must be passed as a list such that each element of the list can be correctly batched for input
|
||||
to a single ControlNet.
|
||||
height (`int`, *optional*, defaults to the size of control_image):
|
||||
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
||||
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
@@ -1184,16 +1217,21 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
|
||||
corresponding scale as a list.
|
||||
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
||||
the corresponding scale as a list.
|
||||
guess_mode (`bool`, *optional*, defaults to `False`):
|
||||
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
|
||||
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
||||
The percentage of total steps at which the controlnet starts applying.
|
||||
The percentage of total steps at which the ControlNet starts applying.
|
||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the controlnet stops applying.
|
||||
The percentage of total steps at which the ControlNet stops applying.
|
||||
control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
|
||||
The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
|
||||
available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
|
||||
where each ControlNet should have its corresponding control mode list. Should reflect the order of
|
||||
conditions in control_image
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
||||
@@ -1273,12 +1311,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
|
||||
if not isinstance(control_image, list):
|
||||
control_image = [control_image]
|
||||
else:
|
||||
@@ -1287,37 +1319,56 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode]
|
||||
|
||||
if len(control_image) != len(control_mode):
|
||||
raise ValueError("Expected len(control_image) == len(control_type)")
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_image = [[item] for item in control_image]
|
||||
control_mode = [[item] for item in control_mode]
|
||||
|
||||
num_control_type = controlnet.config.num_control_type
|
||||
|
||||
# 1. Check inputs
|
||||
control_type = [0 for _ in range(num_control_type)]
|
||||
for _image, control_idx in zip(control_image, control_mode):
|
||||
control_type[control_idx] = 1
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
control_type = torch.Tensor(control_type)
|
||||
if isinstance(controlnet_conditioning_scale, float):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
control_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
control_mode,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_type = [
|
||||
torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
|
||||
for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
|
||||
]
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
@@ -1334,7 +1385,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
global_pool_conditions = controlnet.config.global_pool_conditions
|
||||
global_pool_conditions = (
|
||||
controlnet.config.global_pool_conditions
|
||||
if isinstance(controlnet, ControlNetUnionModel)
|
||||
else controlnet.nets[0].config.global_pool_conditions
|
||||
)
|
||||
guess_mode = guess_mode or global_pool_conditions
|
||||
|
||||
# 3.1. Encode input prompt
|
||||
@@ -1372,22 +1427,55 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 4. Prepare image and controlnet_conditioning_image
|
||||
# 4.1 Prepare image
|
||||
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
|
||||
for idx, _ in enumerate(control_image):
|
||||
control_image[idx] = self.prepare_control_image(
|
||||
image=control_image[idx],
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
height, width = control_image[idx].shape[-2:]
|
||||
# 4.2 Prepare control images
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_images = []
|
||||
|
||||
for image_ in control_image:
|
||||
image_ = self.prepare_control_image(
|
||||
image=image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
control_images.append(image_)
|
||||
|
||||
control_image = control_images
|
||||
height, width = control_image[0].shape[-2:]
|
||||
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_images = []
|
||||
|
||||
for control_image_ in control_image:
|
||||
images = []
|
||||
|
||||
for image_ in control_image_:
|
||||
image_ = self.prepare_control_image(
|
||||
image=image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
images.append(image_)
|
||||
control_images.append(images)
|
||||
|
||||
control_image = control_images
|
||||
height, width = control_image[0][0].shape[-2:]
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -1414,10 +1502,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
# 7.1 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
controlnet_keep.append(
|
||||
1.0
|
||||
- float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
|
||||
)
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps)
|
||||
|
||||
# 7.2 Prepare added time ids & embeddings
|
||||
original_size = original_size or (height, width)
|
||||
@@ -1460,12 +1549,25 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
control_type = (
|
||||
control_type.reshape(1, -1)
|
||||
.to(device, dtype=prompt_embeds.dtype)
|
||||
.repeat(batch_size * num_images_per_prompt * 2, 1)
|
||||
|
||||
control_type_repeat_factor = (
|
||||
batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
|
||||
)
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_type = (
|
||||
control_type.reshape(1, -1)
|
||||
.to(self._execution_device, dtype=prompt_embeds.dtype)
|
||||
.repeat(control_type_repeat_factor, 1)
|
||||
)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_type = [
|
||||
_control_type.reshape(1, -1)
|
||||
.to(self._execution_device, dtype=prompt_embeds.dtype)
|
||||
.repeat(control_type_repeat_factor, 1)
|
||||
for _control_type in control_type
|
||||
]
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
|
||||
@@ -383,7 +383,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = np.array([noise_level]).astype(np.int64)
|
||||
|
||||
@@ -155,7 +155,7 @@ class FluxPipelineFastTests(
|
||||
|
||||
# Outputs should be different here
|
||||
# For some reasons, they don't show large differences
|
||||
assert max_diff > 1e-6
|
||||
self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.")
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
@@ -187,14 +187,17 @@ class FluxPipelineFastTests(
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
|
||||
("Fusion of QKV projections shouldn't affect the outputs."),
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
self.assertTrue(
|
||||
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
|
||||
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
|
||||
("Original outputs should match when fused QKV projections are disabled."),
|
||||
)
|
||||
|
||||
def test_flux_image_output_shape(self):
|
||||
@@ -209,7 +212,11 @@ class FluxPipelineFastTests(
|
||||
inputs.update({"height": height, "width": width})
|
||||
image = pipe(**inputs).images[0]
|
||||
output_height, output_width, _ = image.shape
|
||||
assert (output_height, output_width) == (expected_height, expected_width)
|
||||
self.assertEqual(
|
||||
(output_height, output_width),
|
||||
(expected_height, expected_width),
|
||||
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
|
||||
)
|
||||
|
||||
def test_flux_true_cfg(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
@@ -220,7 +227,9 @@ class FluxPipelineFastTests(
|
||||
inputs["negative_prompt"] = "bad quality"
|
||||
inputs["true_cfg_scale"] = 2.0
|
||||
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
|
||||
assert not np.allclose(no_true_cfg_out, true_cfg_out)
|
||||
self.assertFalse(
|
||||
np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set."
|
||||
)
|
||||
|
||||
|
||||
@nightly
|
||||
@@ -269,45 +278,17 @@ class FluxPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, :10, :10]
|
||||
# fmt: off
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.3242,
|
||||
0.3203,
|
||||
0.3164,
|
||||
0.3164,
|
||||
0.3125,
|
||||
0.3125,
|
||||
0.3281,
|
||||
0.3242,
|
||||
0.3203,
|
||||
0.3301,
|
||||
0.3262,
|
||||
0.3242,
|
||||
0.3281,
|
||||
0.3242,
|
||||
0.3203,
|
||||
0.3262,
|
||||
0.3262,
|
||||
0.3164,
|
||||
0.3262,
|
||||
0.3281,
|
||||
0.3184,
|
||||
0.3281,
|
||||
0.3281,
|
||||
0.3203,
|
||||
0.3281,
|
||||
0.3281,
|
||||
0.3164,
|
||||
0.3320,
|
||||
0.3320,
|
||||
0.3203,
|
||||
],
|
||||
[0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203],
|
||||
dtype=np.float32,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
|
||||
|
||||
assert max_diff < 1e-4
|
||||
self.assertLess(
|
||||
max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
@@ -377,42 +358,14 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, :10, :10]
|
||||
|
||||
# fmt: off
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.1855,
|
||||
0.1680,
|
||||
0.1406,
|
||||
0.1953,
|
||||
0.1699,
|
||||
0.1465,
|
||||
0.2012,
|
||||
0.1738,
|
||||
0.1484,
|
||||
0.2051,
|
||||
0.1797,
|
||||
0.1523,
|
||||
0.2012,
|
||||
0.1719,
|
||||
0.1445,
|
||||
0.2070,
|
||||
0.1777,
|
||||
0.1465,
|
||||
0.2090,
|
||||
0.1836,
|
||||
0.1484,
|
||||
0.2129,
|
||||
0.1875,
|
||||
0.1523,
|
||||
0.2090,
|
||||
0.1816,
|
||||
0.1484,
|
||||
0.2110,
|
||||
0.1836,
|
||||
0.1543,
|
||||
],
|
||||
[0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543],
|
||||
dtype=np.float32,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
|
||||
|
||||
assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
|
||||
self.assertLess(
|
||||
max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user