Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Fixing implementation of ControlNet-XS (#6772)
* CheckIn - created DownSubBlocks
* Added extra channels, implemented subblock fwd
* Fixed connection sizes
* checkin
* Removed iter, next in forward
* Models for SD21 & SDXL run through
* Added back pipelines, cleared up connections
* Cleaned up connection creation
* added debug logs
* updated logs
* logs: added input loading
* Update umer_debug_logger.py
* log: Loading hint
* Update umer_debug_logger.py
* added logs
* Changed debug logging
* debug: added more logs
* Fixed num_norm_groups
* Debug: Logging all of SDXL input
* Update umer_debug_logger.py
* debug: updated logs
* checkin
* Readded tests
* Removed debug logs
* Fixed Slow Tests
* Added value checks | Updated model_cpu_offload_seq
* accelerate-offloading works; fast tests work
* Made unet & addon explicit in controlnet
* Updated slow tests
* Added dtype/device to ControlNetXS
* Filled in test model paths
* Added image_encoder/feature_extractor to XL pipe
* Fixed fast tests
* Added comments and docstrings
* Fixed copies
* Added docs; updated slow tests
* Moved changes to UNetMidBlock2DCrossAttn
* tiny cleanups
* Removed stray prints
* Removed ip adapters + freeU
  - Removed ip adapters + freeU as they don't make sense for ControlNet-XS
  - Fixed imports of UNet components
* Fixed test_save_load_float16
* Make style, quality, fix-copies
* Changed loading/saving API for ControlNetXS
  - Changed loading/saving API for ControlNetXS
  - Other small fixes
* Removed ControlNet-XS from research examples
* Make style, quality, fix-copies
* Small fixes
  - Deleted ControlNetXSModel.init_original
  - Added time_embedding_mix to StableDiffusionControlNetXSPipeline.from_pretrained / StableDiffusionXLControlNetXSPipeline.from_pretrained
  - Fixed copy hints
* checkin May 11 '23
* CheckIn Mar 12 '24
* Fixed tests for SD
* Added tests for UNetControlNetXSModel
* Fixed SDXL tests
* cleanup
* Delete Pipfile
* CheckIn Mar 20 - Started replacing sub blocks by `ControlNetXSCrossAttnDownBlock2D` and `ControlNetXSCrossAttnUpBlock2D`
* check-in Mar 23
* checkin 24 Mar
* Created init for UNetCnxs and CnxsAddon
* CheckIn
* Made from_modules, from_unet and no_control work
* make style, quality, fix-copies & small changes
* Fixed freezing
* Added gradient ckpt'ing; fixed tests
* Fix slow tests (+compile); clear naming confusion
* Don't create UNet in init; removed class_emb
* Incorporated review feedback
  - Deleted get_base_pipeline / get_controlnet_addon for pipes
  - Pipes inherit from StableDiffusionXLPipeline
  - Made module dicts for cnxs-addon's down/mid/up classes
  - Added support for qkv fusion and freeU
* Make style, quality, fix-copies
* Implemented review feedback
* Removed compatibility check for vae/ctrl embedding
* make style, quality, fix-copies
* Delete Pipfile
* Integrated review feedback
  - Importing ControlNetConditioningEmbedding now
  - get_down/mid/up_block_addon now outside class
  - Renamed `do_control` to `apply_control`
* Reduced size of test tensors; for this, added `norm_num_groups` as parameter everywhere
* Renamed cnxs-`Addon` to cnxs-`Adapter`
  - `ControlNetXSAddon` -> `ControlNetXSAdapter`
  - `ControlNetXSAddonDownBlockComponents` -> `DownBlockControlNetXSAdapter`, and similarly for mid/up
  - `get_mid_block_addon` -> `get_mid_block_adapter`, and similarly for down/up
* Fixed save_pretrained/from_pretrained bug
* Removed redundant code

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
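In short, the old monolithic `ControlNetXSModel` is replaced by a lightweight `ControlNetXSAdapter` plus a fused `UNetControlNetXSModel`. A minimal sketch of the reworked loading flow, pieced together from the docstring examples and pipeline code in this diff (the checkpoint is the test weight used in those examples, not an official release):

```py
import torch

from diffusers import ControlNetXSAdapter, StableDiffusionControlNetXSPipeline

# The adapter holds only the small control branch ...
controlnet = ControlNetXSAdapter.from_pretrained(
    "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
)

# ... and the pipeline fuses it with the base UNet internally via
# UNetControlNetXSModel.from_unet(unet, controlnet) in its __init__.
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
```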
@@ -282,6 +282,10 @@
       title: ControlNet
     - local: api/pipelines/controlnet_sdxl
       title: ControlNet with Stable Diffusion XL
+    - local: api/pipelines/controlnetxs
+      title: ControlNet-XS
+    - local: api/pipelines/controlnetxs_sdxl
+      title: ControlNet-XS with Stable Diffusion XL
     - local: api/pipelines/dance_diffusion
       title: Dance Diffusion
     - local: api/pipelines/ddim
@@ -1,3 +1,15 @@
+<!--Copyright 2023 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
 # ControlNet-XS
 
 ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -12,5 +24,16 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/):
 
 This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
 
+<Tip>
+
-> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+## StableDiffusionControlNetXSPipeline
+[[autodoc]] StableDiffusionControlNetXSPipeline
+    - all
+    - __call__
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
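The page above documents only the API surface; for orientation, here is a condensed end-to-end example assembled from the `EXAMPLE_DOC_STRING` this commit adds to the pipeline (same test checkpoints; the canny preprocessing mirrors the docstring):

```py
import cv2
import numpy as np
import torch
from PIL import Image

from diffusers import ControlNetXSAdapter, StableDiffusionControlNetXSPipeline
from diffusers.utils import load_image

controlnet = ControlNetXSAdapter.from_pretrained(
    "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Turn any input image into a 3-channel canny-edge conditioning image.
image = load_image(
    "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
edges = cv2.Canny(np.array(image), 100, 200)
canny_image = Image.fromarray(np.stack([edges] * 3, axis=2))

image = pipe(
    "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting",
    image=canny_image,
    controlnet_conditioning_scale=0.5,
).images[0]
```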
@@ -1,3 +1,15 @@
+<!--Copyright 2023 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
 # ControlNet-XS with Stable Diffusion XL
 
 ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -12,4 +24,22 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/):
 
 This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
 
-> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+<Tip warning={true}>
+
+🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
+
+</Tip>
+
+<Tip>
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+## StableDiffusionXLControlNetXSPipeline
+[[autodoc]] StableDiffusionXLControlNetXSPipeline
+    - all
+    - __call__
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
(File diff suppressed because it is too large.)
@@ -1,58 +0,0 @@
-# !pip install opencv-python transformers accelerate
-import argparse
-
-import cv2
-import numpy as np
-import torch
-from controlnetxs import ControlNetXSModel
-from PIL import Image
-from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
-
-from diffusers.utils import load_image
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
-)
-parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
-parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
-parser.add_argument(
-    "--image_path",
-    type=str,
-    default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
-)
-parser.add_argument("--num_inference_steps", type=int, default=50)
-
-args = parser.parse_args()
-
-prompt = args.prompt
-negative_prompt = args.negative_prompt
-# download an image
-image = load_image(args.image_path)
-
-# initialize the models and pipeline
-controlnet_conditioning_scale = args.controlnet_conditioning_scale
-controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
-pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
-)
-pipe.enable_model_cpu_offload()
-
-# get canny image
-image = np.array(image)
-image = cv2.Canny(image, 100, 200)
-image = image[:, :, None]
-image = np.concatenate([image, image, image], axis=2)
-canny_image = Image.fromarray(image)
-
-num_inference_steps = args.num_inference_steps
-
-# generate image
-image = pipe(
-    prompt,
-    controlnet_conditioning_scale=controlnet_conditioning_scale,
-    image=canny_image,
-    num_inference_steps=num_inference_steps,
-).images[0]
-image.save("cnxs_sd.canny.png")
@@ -1,57 +0,0 @@
-# !pip install opencv-python transformers accelerate
-import argparse
-
-import cv2
-import numpy as np
-import torch
-from controlnetxs import ControlNetXSModel
-from PIL import Image
-from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
-
-from diffusers.utils import load_image
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
-)
-parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
-parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
-parser.add_argument(
-    "--image_path",
-    type=str,
-    default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
-)
-parser.add_argument("--num_inference_steps", type=int, default=50)
-
-args = parser.parse_args()
-
-prompt = args.prompt
-negative_prompt = args.negative_prompt
-# download an image
-image = load_image(args.image_path)
-# initialize the models and pipeline
-controlnet_conditioning_scale = args.controlnet_conditioning_scale
-controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
-pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
-)
-pipe.enable_model_cpu_offload()
-
-# get canny image
-image = np.array(image)
-image = cv2.Canny(image, 100, 200)
-image = image[:, :, None]
-image = np.concatenate([image, image, image], axis=2)
-canny_image = Image.fromarray(image)
-
-num_inference_steps = args.num_inference_steps
-
-# generate image
-image = pipe(
-    prompt,
-    controlnet_conditioning_scale=controlnet_conditioning_scale,
-    image=canny_image,
-    num_inference_steps=num_inference_steps,
-).images[0]
-image.save("cnxs_sdxl.canny.png")
@@ -80,6 +80,7 @@ else:
         "AutoencoderTiny",
         "ConsistencyDecoderVAE",
         "ControlNetModel",
+        "ControlNetXSAdapter",
         "I2VGenXLUNet",
         "Kandinsky3UNet",
         "ModelMixin",
@@ -94,6 +95,7 @@ else:
         "UNet2DConditionModel",
         "UNet2DModel",
         "UNet3DConditionModel",
+        "UNetControlNetXSModel",
         "UNetMotionModel",
         "UNetSpatioTemporalConditionModel",
         "UVit2DModel",
@@ -270,6 +272,7 @@ else:
         "StableDiffusionControlNetImg2ImgPipeline",
         "StableDiffusionControlNetInpaintPipeline",
         "StableDiffusionControlNetPipeline",
+        "StableDiffusionControlNetXSPipeline",
         "StableDiffusionDepth2ImgPipeline",
         "StableDiffusionDiffEditPipeline",
         "StableDiffusionGLIGENPipeline",
@@ -293,6 +296,7 @@ else:
         "StableDiffusionXLControlNetImg2ImgPipeline",
         "StableDiffusionXLControlNetInpaintPipeline",
         "StableDiffusionXLControlNetPipeline",
+        "StableDiffusionXLControlNetXSPipeline",
         "StableDiffusionXLImg2ImgPipeline",
         "StableDiffusionXLInpaintPipeline",
         "StableDiffusionXLInstructPix2PixPipeline",
@@ -474,6 +478,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AutoencoderTiny,
         ConsistencyDecoderVAE,
         ControlNetModel,
+        ControlNetXSAdapter,
         I2VGenXLUNet,
         Kandinsky3UNet,
         ModelMixin,
@@ -487,6 +492,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         UNet2DConditionModel,
         UNet2DModel,
         UNet3DConditionModel,
+        UNetControlNetXSModel,
         UNetMotionModel,
         UNetSpatioTemporalConditionModel,
         UVit2DModel,
@@ -642,6 +648,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionControlNetImg2ImgPipeline,
         StableDiffusionControlNetInpaintPipeline,
         StableDiffusionControlNetPipeline,
+        StableDiffusionControlNetXSPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
         StableDiffusionGLIGENPipeline,
@@ -665,6 +672,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionXLControlNetImg2ImgPipeline,
         StableDiffusionXLControlNetInpaintPipeline,
         StableDiffusionXLControlNetPipeline,
+        StableDiffusionXLControlNetXSPipeline,
         StableDiffusionXLImg2ImgPipeline,
         StableDiffusionXLInpaintPipeline,
         StableDiffusionXLInstructPix2PixPipeline,
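With these registrations, the new classes resolve lazily from the top-level namespace, so a plain import is enough to verify the wiring:

```py
from diffusers import (
    ControlNetXSAdapter,
    StableDiffusionControlNetXSPipeline,
    StableDiffusionXLControlNetXSPipeline,
    UNetControlNetXSModel,
)

# The names above are only resolved to their defining modules on first access.
print(ControlNetXSAdapter.__module__)  # diffusers.models.controlnet_xs
```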
@@ -32,6 +32,7 @@ if is_torch_available():
    _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
    _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
    _import_structure["controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
    _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
    _import_structure["embeddings"] = ["ImageProjection"]
    _import_structure["modeling_utils"] = ["ModelMixin"]
@@ -68,6 +69,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        ConsistencyDecoderVAE,
    )
    from .controlnet import ControlNetModel
+    from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
    from .embeddings import ImageProjection
    from .modeling_utils import ModelMixin
    from .transformers import (
src/diffusers/models/controlnet_xs.py (new file, 1892 lines)
(File diff suppressed because it is too large.)
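The 1892-line model file itself is suppressed above, but its public surface can be read off the call sites elsewhere in this diff. A hedged sketch (names and keyword arguments taken from the pipelines below; treat the exact signatures as approximate):

```py
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel

# Fuse a plain base UNet with the small control adapter, exactly as the
# pipelines' __init__ does.
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet")
controlnet = ControlNetXSAdapter.from_pretrained("UmerHA/Testing-ConrolNetXS-SD2.1-canny")
controlled_unet = UNetControlNetXSModel.from_unet(unet, controlnet)

# The fused model is then called like a UNet with extra control kwargs,
# mirroring the denoising loop further down:
#   noise_pred = controlled_unet(
#       sample=latent_model_input,
#       timestep=t,
#       encoder_hidden_states=prompt_embeds,
#       controlnet_cond=image,
#       conditioning_scale=0.5,
#       apply_control=True,
#       return_dict=True,
#   ).sample
```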
@@ -746,6 +746,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
+        out_channels: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -753,6 +754,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_groups_out: Optional[int] = None,
         resnet_pre_norm: bool = True,
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
@@ -764,6 +766,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ):
         super().__init__()
 
+        out_channels = out_channels or in_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -772,14 +778,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         if isinstance(transformer_layers_per_block, int):
             transformer_layers_per_block = [transformer_layers_per_block] * num_layers
 
+        resnet_groups_out = resnet_groups_out or resnet_groups
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
                 in_channels=in_channels,
-                out_channels=in_channels,
+                out_channels=out_channels,
                 temb_channels=temb_channels,
                 eps=resnet_eps,
                 groups=resnet_groups,
+                groups_out=resnet_groups_out,
                 dropout=dropout,
                 time_embedding_norm=resnet_time_scale_shift,
                 non_linearity=resnet_act_fn,
@@ -794,11 +803,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
+                        norm_num_groups=resnet_groups_out,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         attention_type=attention_type,
@@ -808,8 +817,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 attentions.append(
                     DualTransformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
@@ -817,11 +826,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
             )
             resnets.append(
                 ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
-                    groups=resnet_groups,
+                    groups=resnet_groups_out,
                     dropout=dropout,
                     time_embedding_norm=resnet_time_scale_shift,
                     non_linearity=resnet_act_fn,
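The change above lets the mid block change channel width (previously it was hard-wired to `out_channels == in_channels`) and gives the output path its own GroupNorm group count — both needed because ControlNet-XS concatenates control channels onto the base features. A hedged construction example (channel numbers are illustrative; import path as of this commit):

```py
from diffusers.models.unet_2d_blocks import UNetMidBlock2DCrossAttn

# After this commit the mid block can narrow from a widened input back to the
# base width; resnet_groups_out normalizes everything at the output width.
mid_block = UNetMidBlock2DCrossAttn(
    in_channels=1280 + 64,   # base features + concatenated control features (illustrative)
    out_channels=1280,
    temb_channels=1280,
    resnet_groups=4,         # must divide in_channels (1344)
    resnet_groups_out=32,    # must divide out_channels (1280)
    num_attention_heads=8,
    cross_attention_dim=1024,
)
```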
@@ -134,6 +134,12 @@ else:
             "StableDiffusionXLControlNetPipeline",
         ]
     )
+    _import_structure["controlnet_xs"].extend(
+        [
+            "StableDiffusionControlNetXSPipeline",
+            "StableDiffusionXLControlNetXSPipeline",
+        ]
+    )
     _import_structure["deepfloyd_if"] = [
         "IFImg2ImgPipeline",
         "IFImg2ImgSuperResolutionPipeline",
@@ -378,6 +384,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionXLControlNetInpaintPipeline,
         StableDiffusionXLControlNetPipeline,
     )
+    from .controlnet_xs import (
+        StableDiffusionControlNetXSPipeline,
+        StableDiffusionXLControlNetXSPipeline,
+    )
     from .deepfloyd_if import (
         IFImg2ImgPipeline,
         IFImg2ImgSuperResolutionPipeline,
src/diffusers/pipelines/controlnet_xs/__init__.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_flax_available,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
+    _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
+try:
+    if not (is_transformers_available() and is_flax_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_flax_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
+else:
+    pass  # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *
+    else:
+        from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
+        from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline
+
+    try:
+        if not (is_transformers_available() and is_flax_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_flax_and_transformers_objects import *  # noqa F403
+    else:
+        pass  # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
+
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
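For context, this follows diffusers' usual lazy-import pattern: importing the package only records which names live in which submodule, and `_LazyModule` materializes them on first attribute access (or swaps in dummy objects when torch/transformers are unavailable). A small illustration of the behavior, assuming a working torch + transformers install:

```py
import diffusers.pipelines.controlnet_xs as cnxs

# Nothing heavy has been imported yet; the module object is a _LazyModule.
print(type(cnxs).__name__)  # "_LazyModule"

# First attribute access imports pipeline_controlnet_xs and caches the class.
pipe_cls = cnxs.StableDiffusionControlNetXSPipeline
print(pipe_cls.__module__)  # "diffusers.pipelines.controlnet_xs.pipeline_controlnet_xs"
```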
@@ -19,30 +19,75 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from controlnetxs import ControlNetXSModel
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
-from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.models.lora import adjust_lora_scale_text_encoder
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
     logging,
     replace_example_docstring,
     scale_lora_layers,
     unscale_lora_layers,
 )
-from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> # !pip install opencv-python transformers accelerate
+        >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter
+        >>> from diffusers.utils import load_image
+        >>> import numpy as np
+        >>> import torch
+
+        >>> import cv2
+        >>> from PIL import Image
+
+        >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+        >>> negative_prompt = "low quality, bad quality, sketches"
+
+        >>> # download an image
+        >>> image = load_image(
+        ...     "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+        ... )
+
+        >>> # initialize the models and pipeline
+        >>> controlnet_conditioning_scale = 0.5
+
+        >>> controlnet = ControlNetXSAdapter.from_pretrained(
+        ...     "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
+        ... )
+        >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
+        ... )
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> # get canny image
+        >>> image = np.array(image)
+        >>> image = cv2.Canny(image, 100, 200)
+        >>> image = image[:, :, None]
+        >>> image = np.concatenate([image, image, image], axis=2)
+        >>> canny_image = Image.fromarray(image)
+        >>> # generate image
+        >>> image = pipe(
+        ...     prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
+        ... ).images[0]
+        ```
+"""
+
+
 class StableDiffusionControlNetXSPipeline(
     DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
@@ -56,7 +101,7 @@ class StableDiffusionControlNetXSPipeline(
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
-        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -66,9 +111,9 @@ class StableDiffusionControlNetXSPipeline(
         tokenizer ([`~transformers.CLIPTokenizer`]):
             A `CLIPTokenizer` to tokenize text.
         unet ([`UNet2DConditionModel`]):
-            A `UNet2DConditionModel` to denoise the encoded image latents.
-        controlnet ([`ControlNetXSModel`]):
-            Provides additional conditioning to the `unet` during the denoising process.
+            A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents.
+        controlnet ([`ControlNetXSAdapter`]):
+            A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -80,17 +125,18 @@ class StableDiffusionControlNetXSPipeline(
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """
 
-    model_cpu_offload_seq = "text_encoder->unet->vae>controlnet"
+    model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
-        controlnet: ControlNetXSModel,
+        unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
+        controlnet: ControlNetXSAdapter,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
@@ -98,6 +144,9 @@ class StableDiffusionControlNetXSPipeline(
     ):
         super().__init__()
 
+        if isinstance(unet, UNet2DConditionModel):
+            unet = UNetControlNetXSModel.from_unet(unet, controlnet)
+
         if safety_checker is None and requires_safety_checker:
             logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -114,14 +163,6 @@ class StableDiffusionControlNetXSPipeline(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
-            vae
-        )
-        if not vae_compatible:
-            raise ValueError(
-                f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
-            )
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -403,20 +444,19 @@ class StableDiffusionControlNetXSPipeline(
         self,
         prompt,
         image,
-        callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
 
         if prompt is not None and prompt_embeds is not None:
@@ -445,25 +485,16 @@ class StableDiffusionControlNetXSPipeline(
                 f" {negative_prompt_embeds.shape}."
             )
 
-        # Check `image`
+        # Check `image` and `controlnet_conditioning_scale`
         is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
-            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+            self.unet, torch._dynamo.eval_frame.OptimizedModule
         )
         if (
-            isinstance(self.controlnet, ControlNetXSModel)
+            isinstance(self.unet, UNetControlNetXSModel)
             or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
+            and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
         ):
             self.check_image(image, prompt, prompt_embeds)
-        else:
-            assert False
-
-        # Check `controlnet_conditioning_scale`
-        if (
-            isinstance(self.controlnet, ControlNetXSModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
-        ):
             if not isinstance(controlnet_conditioning_scale, float):
                 raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
         else:
@@ -563,7 +594,33 @@ class StableDiffusionControlNetXSPipeline(
             latents = latents * self.scheduler.init_noise_sigma
         return latents
 
+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
+    def clip_skip(self):
+        return self._clip_skip
+
+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -581,13 +638,13 @@ class StableDiffusionControlNetXSPipeline(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         control_guidance_start: float = 0.0,
         control_guidance_end: float = 1.0,
        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ):
         r"""
         The call function to the pipeline for generation.
@@ -595,7 +652,7 @@ class StableDiffusionControlNetXSPipeline(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                 `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                 specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
@@ -639,12 +696,6 @@ class StableDiffusionControlNetXSPipeline(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -659,7 +710,15 @@ class StableDiffusionControlNetXSPipeline(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
-
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
         Examples:
 
         Returns:
@@ -669,21 +728,27 @@ class StableDiffusionControlNetXSPipeline(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
-        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+        unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             image,
-            callback_steps,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
        )
 
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -713,6 +778,7 @@ class StableDiffusionControlNetXSPipeline(
             lora_scale=text_encoder_lora_scale,
             clip_skip=clip_skip,
         )
+
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
@@ -720,27 +786,24 @@ class StableDiffusionControlNetXSPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare image
-        if isinstance(controlnet, ControlNetXSModel):
-            image = self.prepare_image(
-                image=image,
-                width=width,
-                height=height,
-                batch_size=batch_size * num_images_per_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
-            )
-            height, width = image.shape[-2:]
-        else:
-            assert False
+        image = self.prepare_image(
+            image=image,
+            width=width,
+            height=height,
+            batch_size=batch_size * num_images_per_prompt,
+            num_images_per_prompt=num_images_per_prompt,
+            device=device,
+            dtype=unet.dtype,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+        )
+        height, width = image.shape[-2:]
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
 
         # 6. Prepare latent variables
-        num_channels_latents = self.unet.config.in_channels
+        num_channels_latents = self.unet.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -757,42 +820,33 @@ class StableDiffusionControlNetXSPipeline(
 
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        is_unet_compiled = is_compiled_module(self.unet)
-        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        self._num_timesteps = len(timesteps)
+        is_controlnet_compiled = is_compiled_module(self.unet)
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
-                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                if is_controlnet_compiled and is_torch_higher_equal_2_1:
                     torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # predict the noise residual
-                dont_control = (
-                    i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
+                apply_control = (
+                    i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
                 )
-                if dont_control:
-                    noise_pred = self.unet(
-                        sample=latent_model_input,
-                        timestep=t,
-                        encoder_hidden_states=prompt_embeds,
-                        cross_attention_kwargs=cross_attention_kwargs,
-                        return_dict=True,
-                    ).sample
-                else:
-                    noise_pred = self.controlnet(
-                        base_model=self.unet,
-                        sample=latent_model_input,
-                        timestep=t,
-                        encoder_hidden_states=prompt_embeds,
-                        controlnet_cond=image,
-                        conditioning_scale=controlnet_conditioning_scale,
-                        cross_attention_kwargs=cross_attention_kwargs,
-                        return_dict=True,
-                    ).sample
+                noise_pred = self.unet(
+                    sample=latent_model_input,
+                    timestep=t,
+                    encoder_hidden_states=prompt_embeds,
+                    controlnet_cond=image,
+                    conditioning_scale=controlnet_conditioning_scale,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    return_dict=True,
+                    apply_control=apply_control,
+                ).sample
 
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -801,12 +855,18 @@ class StableDiffusionControlNetXSPipeline(
 
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
-                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        step_idx = i // getattr(self.scheduler, "order", 1)
-                        callback(step_idx, t, latents)
 
         # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
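The `apply_control` window above replaces the old inverted `dont_control` branch: control is applied only while the current step's position in the schedule lies inside `[control_guidance_start, control_guidance_end]`, and the fused UNet handles both the controlled and uncontrolled paths. A small standalone illustration of the windowing logic (plain Python, no diffusers required):

```py
num_inference_steps = 10
control_guidance_start, control_guidance_end = 0.2, 0.8

for i in range(num_inference_steps):
    # Same predicate as in the pipeline: the step's start fraction must be past
    # the window start, and its end fraction must not exceed the window end.
    apply_control = (
        i / num_inference_steps >= control_guidance_start
        and (i + 1) / num_inference_steps <= control_guidance_end
    )
    print(i, apply_control)
# Steps 2 through 7 print True; the first two and last two steps skip control.
```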
@@ -19,41 +19,94 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # !pip install opencv-python transformers accelerate
|
||||
>>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
|
||||
>>> import cv2
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
>>> negative_prompt = "low quality, bad quality, sketches"
|
||||
|
||||
>>> # download an image
|
||||
>>> image = load_image(
|
||||
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
||||
... )
|
||||
|
||||
>>> # initialize the models and pipeline
|
||||
>>> controlnet_conditioning_scale = 0.5
|
||||
>>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
||||
>>> controlnet = ControlNetXSAdapter.from_pretrained(
|
||||
... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> # get canny image
|
||||
>>> image = np.array(image)
|
||||
>>> image = cv2.Canny(image, 100, 200)
|
||||
>>> image = image[:, :, None]
|
||||
>>> image = np.concatenate([image, image, image], axis=2)
|
||||
>>> canny_image = Image.fromarray(image)
|
||||
|
||||
>>> # generate image
|
||||
>>> image = pipe(
|
||||
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
|
||||
... ).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipeline(
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
@@ -66,9 +119,8 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -83,9 +135,9 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
tokenizer_2 ([`~transformers.CLIPTokenizer`]):
|
||||
A `CLIPTokenizer` to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded image latents.
|
||||
controlnet ([`ControlNetXSModel`]:
|
||||
Provides additional conditioning to the `unet` during the denoising process.
|
||||
A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents.
|
||||
controlnet ([`ControlNetXSAdapter`]):
|
||||
A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
@@ -98,9 +150,15 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
watermarker is used.
|
||||
"""
|
||||
|
||||
# leave controlnet out on purpose because it iterates with unet
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = [
|
||||
"tokenizer",
|
||||
"tokenizer_2",
|
||||
"text_encoder",
|
||||
"text_encoder_2",
|
||||
"feature_extractor",
|
||||
]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -109,21 +167,17 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: ControlNetXSModel,
|
||||
unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
|
||||
controlnet: ControlNetXSAdapter,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
|
||||
vae
|
||||
)
|
||||
if not vae_compatible:
|
||||
raise ValueError(
|
||||
f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
|
||||
)
|
||||
if isinstance(unet, UNet2DConditionModel):
|
||||
unet = UNetControlNetXSModel.from_unet(unet, controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
@@ -134,6 +188,7 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
||||
@@ -417,15 +472,21 @@ class StableDiffusionXLControlNetXSPipeline(
        controlnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        callback_on_step_end_tensor_inputs=None,
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -474,25 +535,16 @@ class StableDiffusionXLControlNetXSPipeline(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        # Check `image`
        # Check `image` and ``controlnet_conditioning_scale``
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
            self.unet, torch._dynamo.eval_frame.OptimizedModule
        )
        if (
            isinstance(self.controlnet, ControlNetXSModel)
            isinstance(self.unet, UNetControlNetXSModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
            and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
        ):
            self.check_image(image, prompt, prompt_embeds)
        else:
            assert False

        # Check `controlnet_conditioning_scale`
        if (
            isinstance(self.controlnet, ControlNetXSModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
        else:
@@ -593,7 +645,6 @@ class StableDiffusionXLControlNetXSPipeline(
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
    def _get_add_time_ids(
        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    ):
@@ -602,7 +653,7 @@ class StableDiffusionXLControlNetXSPipeline(
        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
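For orientation, a worked example (assumed SDXL-base-like values, not taken from this commit) of the embedding-dimension check around this hunk:

# add_time_ids concatenates original_size + crops_coords_top_left + target_size -> 6 integers
addition_time_embed_dim = 256  # assumed SDXL-style value
text_encoder_projection_dim = 1280  # assumed SDXL-style value
passed_add_embed_dim = addition_time_embed_dim * 6 + text_encoder_projection_dim
assert passed_add_embed_dim == 2816  # must equal base_add_embedding.linear_1.in_features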
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
        expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
@@ -632,7 +683,33 @@ class StableDiffusionXLControlNetXSPipeline(
        self.vae.decoder.conv_in.to(dtype)
        self.vae.decoder.mid_block.to(dtype)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
    def guidance_scale(self):
        return self._guidance_scale

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
    def clip_skip(self):
        return self._clip_skip

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
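When this property is `True`, the two stacked noise predictions are later combined in the usual classifier-free-guidance form; a minimal sketch (illustrative, not the pipeline's literal code):

import torch

def cfg_combine(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # The batch stacks the unconditional and the text-conditional predictions,
    # because the latents were duplicated before the UNet call.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)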

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
    def num_timesteps(self):
        return self._num_timesteps

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
@@ -654,8 +731,6 @@ class StableDiffusionXLControlNetXSPipeline(
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        control_guidance_start: float = 0.0,
@@ -667,6 +742,9 @@ class StableDiffusionXLControlNetXSPipeline(
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

@@ -677,7 +755,7 @@ class StableDiffusionXLControlNetXSPipeline(
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
@@ -735,12 +813,6 @@ class StableDiffusionXLControlNetXSPipeline(
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that is called every `callback_steps` steps during inference. The function is called with
                the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -783,6 +855,15 @@ class StableDiffusionXLControlNetXSPipeline(
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

@@ -791,7 +872,24 @@ class StableDiffusionXLControlNetXSPipeline(
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is
            returned, otherwise a `tuple` is returned containing the output images.
        """
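A minimal usage sketch for the `callback_on_step_end` hook documented above (the handler below is a hypothetical example, not part of this commit):

# Inspect the latents at every step and hand them back unchanged.
def log_latents(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents std = {latents.std():.4f}")
    return callback_kwargs  # returned tensors replace the pipeline's own

# output = pipe(
#     prompt="bird",
#     image=canny_image,
#     callback_on_step_end=log_latents,
#     callback_on_step_end_tensor_inputs=["latents"],
# )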
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
@@ -808,8 +906,14 @@ class StableDiffusionXLControlNetXSPipeline(
            controlnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
@@ -850,7 +954,7 @@ class StableDiffusionXLControlNetXSPipeline(
        )

        # 4. Prepare image
        if isinstance(controlnet, ControlNetXSModel):
        if isinstance(unet, UNetControlNetXSModel):
            image = self.prepare_image(
                image=image,
                width=width,
@@ -858,7 +962,7 @@ class StableDiffusionXLControlNetXSPipeline(
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                dtype=unet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
            )
            height, width = image.shape[-2:]
@@ -870,7 +974,7 @@ class StableDiffusionXLControlNetXSPipeline(
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
@@ -928,14 +1032,14 @@ class StableDiffusionXLControlNetXSPipeline(

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        is_unet_compiled = is_compiled_module(self.unet)
        is_controlnet_compiled = is_compiled_module(self.controlnet)
        self._num_timesteps = len(timesteps)
        is_controlnet_compiled = is_compiled_module(self.unet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Relevant thread:
                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                if is_controlnet_compiled and is_torch_higher_equal_2_1:
                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -944,30 +1048,20 @@ class StableDiffusionXLControlNetXSPipeline(
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                # predict the noise residual
                dont_control = (
                    i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
                apply_control = (
                    i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
                )
                if dont_control:
                    noise_pred = self.unet(
                        sample=latent_model_input,
                        timestep=t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=True,
                    ).sample
                else:
                    noise_pred = self.controlnet(
                        base_model=self.unet,
                        sample=latent_model_input,
                        timestep=t,
                        encoder_hidden_states=prompt_embeds,
                        controlnet_cond=image,
                        conditioning_scale=controlnet_conditioning_scale,
                        cross_attention_kwargs=cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=True,
                    ).sample
                noise_pred = self.unet(
                    sample=latent_model_input,
                    timestep=t,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=controlnet_conditioning_scale,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=True,
                    apply_control=apply_control,
                ).sample
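To make the `apply_control` window above concrete, a small illustration (pure arithmetic, not pipeline code) of which steps receive control for given start/end fractions:

def control_schedule(num_steps: int, start: float, end: float) -> list:
    # Step i is controlled iff its slot in the schedule lies inside [start, end].
    return [i / num_steps >= start and (i + 1) / num_steps <= end for i in range(num_steps)]

# control_schedule(10, 0.0, 0.5) -> control on steps 0-4, plain UNet behaviour afterwards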

                # perform guidance
                if do_classifier_free_guidance:
@@ -977,6 +1071,16 @@ class StableDiffusionXLControlNetXSPipeline(
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
@@ -984,6 +1088,11 @@ class StableDiffusionXLControlNetXSPipeline(
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        # manually for max memory savings
        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
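The float16 handling above follows the standard SDXL decoding pattern; a condensed, illustrative sketch (names like `vae` and `latents` stand in for the pipeline's attributes):

# The SDXL VAE overflows in float16, so decoding temporarily runs in float32.
if vae.dtype == torch.float16 and vae.config.force_upcast:
    vae = vae.to(torch.float32)  # stands in for self.upcast_vae()
    latents = latents.to(torch.float32)  # keep latents and VAE dtypes aligned
image = vae.decode(latents / vae.config.scaling_factor).sample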

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
@@ -2238,6 +2238,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
        self,
        in_channels: int,
        temb_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -2245,6 +2246,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_groups_out: Optional[int] = None,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
@@ -2256,6 +2258,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
    ):
        super().__init__()

        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -2264,14 +2270,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_groups_out = resnet_groups_out or resnet_groups
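A quick illustration (assumed numbers) of the group-count defaulting above; GroupNorm requires the group count to divide the channel count:

in_channels = 640  # assumed example value
resnet_groups = min(in_channels // 4, 32)  # -> 32
assert in_channels % resnet_groups == 0  # GroupNorm divisibility requirement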

        # there is always at least one resnet
        resnets = [
            ResnetBlockFlat(
                in_channels=in_channels,
                out_channels=in_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                groups_out=resnet_groups_out,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
@@ -2286,11 +2295,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block[i],
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        norm_num_groups=resnet_groups_out,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                        attention_type=attention_type,
@@ -2300,8 +2309,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                attentions.append(
                    DualTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
@@ -2309,11 +2318,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                )
            resnets.append(
                ResnetBlockFlat(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    groups=resnet_groups_out,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,

@@ -92,6 +92,21 @@ class ControlNetModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class ControlNetXSAdapter(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class I2VGenXLUNet(metaclass=DummyObject):
    _backends = ["torch"]

@@ -287,6 +302,21 @@ class UNet3DConditionModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class UNetControlNetXSModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class UNetMotionModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -902,6 +902,21 @@ class StableDiffusionControlNetPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

@@ -1247,6 +1262,21 @@ class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

352 tests/models/unets/test_models_unet_controlnetxs.py Normal file
@@ -0,0 +1,352 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np
import torch
from torch import nn

from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


logger = logging.get_logger(__name__)

enable_full_determinism()


class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNetControlNetXSModel
    main_input_name = "sample"

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (16, 16)
        conditioning_image_size = (3, 32, 32)  # size of additional, unprocessed image for control-conditioning

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
        controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
        conditioning_scale = 1

        return {
            "sample": noise,
            "timestep": time_step,
            "encoder_hidden_states": encoder_hidden_states,
            "controlnet_cond": controlnet_cond,
            "conditioning_scale": conditioning_scale,
        }

    @property
    def input_shape(self):
        return (4, 16, 16)

    @property
    def output_shape(self):
        return (4, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "sample_size": 16,
            "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
            "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
            "block_out_channels": (4, 8),
            "cross_attention_dim": 8,
            "transformer_layers_per_block": 1,
            "num_attention_heads": 2,
            "norm_num_groups": 4,
            "upcast_attention": False,
            "ctrl_block_out_channels": [2, 4],
            "ctrl_num_attention_heads": 4,
            "ctrl_max_norm_num_groups": 2,
            "ctrl_conditioning_embedding_out_channels": (2, 2),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
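A sketch of how these dummy args are consumed by the common model tests (mirrors the `ModelTesterMixin` pattern; purely illustrative):

# init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
# model = UNetControlNetXSModel(**init_dict).to(torch_device)
# with torch.no_grad():
#     sample = model(**inputs_dict).sample  # expected shape: (4, 4, 16, 16)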

    def get_dummy_unet(self):
        """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
        return UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            sample_size=16,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=8,
            norm_num_groups=4,
            use_linear_projection=True,
        )

    def get_dummy_controlnet_from_unet(self, unet, **kwargs):
        """For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
        # size_ratio and conditioning_embedding_out_channels chosen to keep model small
        return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)

    def test_from_unet(self):
        unet = self.get_dummy_unet()
        controlnet = self.get_dummy_controlnet_from_unet(unet)

        model = UNetControlNetXSModel.from_unet(unet, controlnet)
        model_state_dict = model.state_dict()

        def assert_equal_weights(module, weight_dict_prefix):
            for param_name, param_value in module.named_parameters():
                assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)

        # # check unet
        # everything except down,mid,up blocks
        modules_from_unet = [
            "time_embedding",
            "conv_in",
            "conv_norm_out",
            "conv_out",
        ]
        for p in modules_from_unet:
            assert_equal_weights(getattr(unet, p), "base_" + p)
        optional_modules_from_unet = [
            "class_embedding",
            "add_time_proj",
            "add_embedding",
        ]
        for p in optional_modules_from_unet:
            if hasattr(unet, p) and getattr(unet, p) is not None:
                assert_equal_weights(getattr(unet, p), "base_" + p)
        # down blocks
        assert len(unet.down_blocks) == len(model.down_blocks)
        for i, d in enumerate(unet.down_blocks):
            assert_equal_weights(d.resnets, f"down_blocks.{i}.base_resnets")
            if hasattr(d, "attentions"):
                assert_equal_weights(d.attentions, f"down_blocks.{i}.base_attentions")
            if hasattr(d, "downsamplers") and getattr(d, "downsamplers") is not None:
                assert_equal_weights(d.downsamplers[0], f"down_blocks.{i}.base_downsamplers")
        # mid block
        assert_equal_weights(unet.mid_block, "mid_block.base_midblock")
        # up blocks
        assert len(unet.up_blocks) == len(model.up_blocks)
        for i, u in enumerate(unet.up_blocks):
            assert_equal_weights(u.resnets, f"up_blocks.{i}.resnets")
            if hasattr(u, "attentions"):
                assert_equal_weights(u.attentions, f"up_blocks.{i}.attentions")
            if hasattr(u, "upsamplers") and getattr(u, "upsamplers") is not None:
                assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")

        # # check controlnet
        # everything except down,mid,up blocks
        modules_from_controlnet = {
            "controlnet_cond_embedding": "controlnet_cond_embedding",
            "conv_in": "ctrl_conv_in",
            "control_to_base_for_conv_in": "control_to_base_for_conv_in",
        }
        optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"}
        for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items():
            assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)

        for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items():
            if hasattr(controlnet, name_in_controlnet) and getattr(controlnet, name_in_controlnet) is not None:
                assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
        # down blocks
        assert len(controlnet.down_blocks) == len(model.down_blocks)
        for i, d in enumerate(controlnet.down_blocks):
            assert_equal_weights(d.resnets, f"down_blocks.{i}.ctrl_resnets")
            assert_equal_weights(d.base_to_ctrl, f"down_blocks.{i}.base_to_ctrl")
            assert_equal_weights(d.ctrl_to_base, f"down_blocks.{i}.ctrl_to_base")
            if d.attentions is not None:
                assert_equal_weights(d.attentions, f"down_blocks.{i}.ctrl_attentions")
            if d.downsamplers is not None:
                assert_equal_weights(d.downsamplers, f"down_blocks.{i}.ctrl_downsamplers")
        # mid block
        assert_equal_weights(controlnet.mid_block.base_to_ctrl, "mid_block.base_to_ctrl")
        assert_equal_weights(controlnet.mid_block.midblock, "mid_block.ctrl_midblock")
        assert_equal_weights(controlnet.mid_block.ctrl_to_base, "mid_block.ctrl_to_base")
        # up blocks
        assert len(controlnet.up_connections) == len(model.up_blocks)
        for i, u in enumerate(controlnet.up_connections):
            assert_equal_weights(u.ctrl_to_base, f"up_blocks.{i}.ctrl_to_base")

    def test_freeze_unet(self):
        def assert_frozen(module):
            for p in module.parameters():
                assert not p.requires_grad

        def assert_unfrozen(module):
            for p in module.parameters():
                assert p.requires_grad

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = UNetControlNetXSModel(**init_dict)
        model.freeze_unet_params()

        # # check unet
        # everything except down,mid,up blocks
        modules_from_unet = [
            model.base_time_embedding,
            model.base_conv_in,
            model.base_conv_norm_out,
            model.base_conv_out,
        ]
        for m in modules_from_unet:
            assert_frozen(m)

        optional_modules_from_unet = [
            model.base_add_time_proj,
            model.base_add_embedding,
        ]
        for m in optional_modules_from_unet:
            if m is not None:
                assert_frozen(m)

        # down blocks
        for i, d in enumerate(model.down_blocks):
            assert_frozen(d.base_resnets)
            if isinstance(d.base_attentions, nn.ModuleList):  # attentions can be list of Nones
                assert_frozen(d.base_attentions)
            if d.base_downsamplers is not None:
                assert_frozen(d.base_downsamplers)

        # mid block
        assert_frozen(model.mid_block.base_midblock)

        # up blocks
        for i, u in enumerate(model.up_blocks):
            assert_frozen(u.resnets)
            if isinstance(u.attentions, nn.ModuleList):  # attentions can be list of Nones
                assert_frozen(u.attentions)
            if u.upsamplers is not None:
                assert_frozen(u.upsamplers)

        # # check controlnet
        # everything except down,mid,up blocks
        modules_from_controlnet = [
            model.controlnet_cond_embedding,
            model.ctrl_conv_in,
            model.control_to_base_for_conv_in,
        ]
        optional_modules_from_controlnet = [model.ctrl_time_embedding]

        for m in modules_from_controlnet:
            assert_unfrozen(m)
        for m in optional_modules_from_controlnet:
            if m is not None:
                assert_unfrozen(m)

        # down blocks
        for d in model.down_blocks:
            assert_unfrozen(d.ctrl_resnets)
            assert_unfrozen(d.base_to_ctrl)
            assert_unfrozen(d.ctrl_to_base)
            if isinstance(d.ctrl_attentions, nn.ModuleList):  # attentions can be list of Nones
                assert_unfrozen(d.ctrl_attentions)
            if d.ctrl_downsamplers is not None:
                assert_unfrozen(d.ctrl_downsamplers)
        # mid block
        assert_unfrozen(model.mid_block.base_to_ctrl)
        assert_unfrozen(model.mid_block.ctrl_midblock)
        assert_unfrozen(model.mid_block.ctrl_to_base)
        # up blocks
        for u in model.up_blocks:
            assert_unfrozen(u.ctrl_to_base)

    def test_gradient_checkpointing_is_applied(self):
        model_class_copy = copy.copy(UNetControlNetXSModel)

        modules_with_gc_enabled = {}

        # now monkey patch the following function:
        # def _set_gradient_checkpointing(self, module, value=False):
        #     if hasattr(module, "gradient_checkpointing"):
        #         module.gradient_checkpointing = value

        def _set_gradient_checkpointing_new(self, module, value=False):
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = value
                modules_with_gc_enabled[module.__class__.__name__] = True

        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = model_class_copy(**init_dict)

        model.enable_gradient_checkpointing()

        EXPECTED_SET = {
            "Transformer2DModel",
            "UNetMidBlock2DCrossAttn",
            "ControlNetXSCrossAttnDownBlock2D",
            "ControlNetXSCrossAttnMidBlock2D",
            "ControlNetXSCrossAttnUpBlock2D",
        }

        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

    def test_forward_no_control(self):
        unet = self.get_dummy_unet()
        controlnet = self.get_dummy_controlnet_from_unet(unet)

        model = UNetControlNetXSModel.from_unet(unet, controlnet)

        unet = unet.to(torch_device)
        model = model.to(torch_device)

        input_ = self.dummy_input

        control_specific_input = ["controlnet_cond", "conditioning_scale"]
        input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}

        with torch.no_grad():
            unet_output = unet(**input_for_unet).sample.cpu()
            unet_controlnet_output = model(**input_, apply_control=False).sample.cpu()

        assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4

    def test_time_embedding_mixing(self):
        unet = self.get_dummy_unet()
        controlnet = self.get_dummy_controlnet_from_unet(unet)
        controlnet_mix_time = self.get_dummy_controlnet_from_unet(
            unet, time_embedding_mix=0.5, learn_time_embedding=True
        )

        model = UNetControlNetXSModel.from_unet(unet, controlnet)
        model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time)

        unet = unet.to(torch_device)
        model = model.to(torch_device)
        model_mix_time = model_mix_time.to(torch_device)

        input_ = self.dummy_input

        with torch.no_grad():
            output = model(**input_).sample
            output_mix_time = model_mix_time(**input_).sample

        assert output.shape == output_mix_time.shape

    def test_forward_with_norm_groups(self):
        # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
        pass
0 tests/pipelines/controlnet_xs/__init__.py Normal file
366 tests/pipelines/controlnet_xs/test_controlnetxs.py Normal file
@@ -0,0 +1,366 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import traceback
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AsymmetricAutoencoderKL,
    AutoencoderKL,
    AutoencoderTiny,
    ConsistencyDecoderVAE,
    ControlNetXSAdapter,
    DDIMScheduler,
    LCMScheduler,
    StableDiffusionControlNetXSPipeline,
    UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    load_image,
    load_numpy,
    require_python39_or_higher,
    require_torch_2,
    require_torch_gpu,
    run_test_in_subprocess,
    slow,
    torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from ...models.autoencoders.test_models_vae import (
    get_asym_autoencoder_kl_config,
    get_autoencoder_kl_config,
    get_autoencoder_tiny_config,
    get_consistency_vae_config,
)
from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_TO_IMAGE_BATCH_PARAMS,
    TEXT_TO_IMAGE_IMAGE_PARAMS,
    TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import (
    PipelineKarrasSchedulerTesterMixin,
    PipelineLatentTesterMixin,
    PipelineTesterMixin,
    SDFunctionTesterMixin,
)


enable_full_determinism()


# Will be run via run_test_in_subprocess
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
    error = None
    try:
        _ = in_queue.get(timeout=timeout)

        controlnet = ControlNetXSAdapter.from_pretrained(
            "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base",
            controlnet=controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        pipe.to("cuda")
        pipe.set_progress_bar_config(disable=None)

        pipe.unet.to(memory_format=torch.channels_last)
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "bird"
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
        ).resize((512, 512))

        output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
        image = output.images[0]

        assert image.shape == (512, 512, 3)

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
        )
        expected_image = np.resize(expected_image, (512, 512, 3))

        assert np.abs(expected_image - image).max() < 1.0

    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()


class ControlNetXSPipelineFastTests(
    PipelineLatentTesterMixin,
    PipelineKarrasSchedulerTesterMixin,
    PipelineTesterMixin,
    SDFunctionTesterMixin,
    unittest.TestCase,
):
    pipeline_class = StableDiffusionControlNetXSPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    test_attention_slicing = False

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            sample_size=16,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=8,
            norm_num_groups=4,
            time_cond_proj_dim=time_cond_proj_dim,
            use_linear_projection=True,
        )
        torch.manual_seed(0)
        controlnet = ControlNetXSAdapter.from_unet(
            unet=unet,
            size_ratio=1,
            learn_time_embedding=True,
            conditioning_embedding_out_channels=(2, 2),
        )
        torch.manual_seed(0)
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=8,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "controlnet": controlnet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        controlnet_embedder_scale_factor = 2
        image = randn_tensor(
            (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
            generator=generator,
            device=torch.device(device),
        )

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
            "image": image,
        }

        return inputs

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_attention_forwardGenerator_pass(self):
        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=2e-3)

    def test_controlnet_lcm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components(time_cond_proj_dim=8)
        sd_pipe = StableDiffusionControlNetXSPipeline(**components)
        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = sd_pipe(**inputs)
        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 16, 16, 3)
        expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_to_dtype(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

        pipe.to(dtype=torch.float16)
        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

    def test_multi_vae(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        block_out_channels = pipe.vae.config.block_out_channels
        norm_num_groups = pipe.vae.config.norm_num_groups

        vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
        configs = [
            get_autoencoder_kl_config(block_out_channels, norm_num_groups),
            get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
            get_consistency_vae_config(block_out_channels, norm_num_groups),
            get_autoencoder_tiny_config(block_out_channels),
        ]

        out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]

        for vae_cls, config in zip(vae_classes, configs):
            vae = vae_cls(**config)
            vae = vae.to(torch_device)
            components["vae"] = vae
            vae_pipe = self.pipeline_class(**components)

            # pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device.
            # So we need to move the new pipe to device.
            vae_pipe.to(torch_device)
            vae_pipe.set_progress_bar_config(disable=None)

            out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]

            assert out_vae_np.shape == out_np.shape


@slow
@require_torch_gpu
class ControlNetXSPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_canny(self):
        controlnet = ControlNetXSAdapter.from_pretrained(
            "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "bird"
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
        )

        output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)

        image = output.images[0]

        assert image.shape == (768, 512, 3)

        original_image = image[-3:, -3:, -1].flatten()
        expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808])
        assert np.allclose(original_image, expected_image, atol=1e-04)

    def test_depth(self):
        controlnet = ControlNetXSAdapter.from_pretrained(
            "UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "Stormtrooper's lecture"
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
        )

        output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)

        image = output.images[0]

        assert image.shape == (512, 512, 3)

        original_image = image[-3:, -3:, -1].flatten()
        expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
        assert np.allclose(original_image, expected_image, atol=1e-04)

    @require_python39_or_higher
    @require_torch_2
    def test_stable_diffusion_compile(self):
        run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
425 tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py Normal file
@@ -0,0 +1,425 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import (
    AsymmetricAutoencoderKL,
    AutoencoderKL,
    AutoencoderTiny,
    ConsistencyDecoderVAE,
    ControlNetXSAdapter,
    EulerDiscreteScheduler,
    StableDiffusionXLControlNetXSPipeline,
    UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
from diffusers.utils.torch_utils import randn_tensor

from ...models.autoencoders.test_models_vae import (
    get_asym_autoencoder_kl_config,
    get_autoencoder_kl_config,
    get_autoencoder_tiny_config,
    get_consistency_vae_config,
)
from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_TO_IMAGE_BATCH_PARAMS,
    TEXT_TO_IMAGE_IMAGE_PARAMS,
    TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import (
    PipelineKarrasSchedulerTesterMixin,
    PipelineLatentTesterMixin,
    PipelineTesterMixin,
    SDXLOptionalComponentsTesterMixin,
)


enable_full_determinism()


class StableDiffusionXLControlNetXSPipelineFastTests(
    PipelineLatentTesterMixin,
    PipelineKarrasSchedulerTesterMixin,
    PipelineTesterMixin,
    SDXLOptionalComponentsTesterMixin,
    unittest.TestCase,
):
    pipeline_class = StableDiffusionXLControlNetXSPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    test_attention_slicing = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            sample_size=16,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            use_linear_projection=True,
            norm_num_groups=4,
            # SD2-specific config below
            attention_head_dim=(2, 4),
            addition_embed_type="text_time",
            addition_time_embed_dim=8,
            transformer_layers_per_block=(1, 2),
            projection_class_embeddings_input_dim=56,  # 6 * 8 (addition_time_embed_dim) + 8 (cross_attention_dim)
            cross_attention_dim=8,
        )
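Double-checking the `projection_class_embeddings_input_dim` comment above with explicit arithmetic (illustrative only):

num_time_ids = 6  # original_size + crops_coords_top_left + target_size
addition_time_embed_dim = 8
pooled_text_embed_dim = 8  # matches cross_attention_dim in this dummy config
assert num_time_ids * addition_time_embed_dim + pooled_text_embed_dim == 56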
        torch.manual_seed(0)
        controlnet = ControlNetXSAdapter.from_unet(
            unet=unet,
            size_ratio=0.5,
            learn_time_embedding=True,
            conditioning_embedding_out_channels=(2, 2),
        )
        torch.manual_seed(0)
        scheduler = EulerDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            steps_offset=1,
            beta_schedule="scaled_linear",
            timestep_spacing="leading",
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=4,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=8,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "controlnet": controlnet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "text_encoder_2": text_encoder_2,
            "tokenizer_2": tokenizer_2,
            "feature_extractor": None,
        }
        return components

    # copied from test_controlnet_sdxl.py
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        controlnet_embedder_scale_factor = 2
        image = randn_tensor(
            (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
            generator=generator,
            device=torch.device(device),
        )

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "np",
            "image": image,
        }

        return inputs

    # copied from test_controlnet_sdxl.py
    def test_attention_slicing_forward_pass(self):
        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)

    # copied from test_controlnet_sdxl.py
    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_attention_forwardGenerator_pass(self):
        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)

    # copied from test_controlnet_sdxl.py
    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=2e-3)

    # copied from test_controlnet_sdxl.py
    @require_torch_gpu
    def test_stable_diffusion_xl_offloads(self):
        pipes = []
        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components).to(torch_device)
        pipes.append(sd_pipe)

        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components)
        sd_pipe.enable_model_cpu_offload()
        pipes.append(sd_pipe)

        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components)
        sd_pipe.enable_sequential_cpu_offload()
        pipes.append(sd_pipe)

        image_slices = []
        for pipe in pipes:
            pipe.unet.set_default_attn_processor()

            inputs = self.get_dummy_inputs(torch_device)
            image = pipe(**inputs).images

            image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
        assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3

    # copied from test_controlnet_sdxl.py
    def test_stable_diffusion_xl_multi_prompts(self):
        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components).to(torch_device)

        # forward with single prompt
        inputs = self.get_dummy_inputs(torch_device)
        output = sd_pipe(**inputs)
        image_slice_1 = output.images[0, -3:, -3:, -1]

        # forward with same prompt duplicated
        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt_2"] = inputs["prompt"]
        output = sd_pipe(**inputs)
        image_slice_2 = output.images[0, -3:, -3:, -1]

        # ensure the results are equal
        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4

        # forward with different prompt
        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt_2"] = "different prompt"
        output = sd_pipe(**inputs)
        image_slice_3 = output.images[0, -3:, -3:, -1]

        # ensure the results are not equal
        assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

        # manually set a negative_prompt
        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt"] = "negative prompt"
        output = sd_pipe(**inputs)
        image_slice_1 = output.images[0, -3:, -3:, -1]

        # forward with same negative_prompt duplicated
        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt"] = "negative prompt"
        inputs["negative_prompt_2"] = inputs["negative_prompt"]
        output = sd_pipe(**inputs)
        image_slice_2 = output.images[0, -3:, -3:, -1]

        # ensure the results are equal
        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4

        # forward with different negative_prompt
        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt"] = "negative prompt"
        inputs["negative_prompt_2"] = "different negative prompt"
        output = sd_pipe(**inputs)
        image_slice_3 = output.images[0, -3:, -3:, -1]

        # ensure the results are not equal
        assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

    # copied from test_stable_diffusion_xl.py
    def test_stable_diffusion_xl_prompt_embeds(self):
        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        # forward without prompt embeds
        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = 2 * [inputs["prompt"]]
        inputs["num_images_per_prompt"] = 2

        output = sd_pipe(**inputs)
        image_slice_1 = output.images[0, -3:, -3:, -1]

        # forward with prompt embeds
        inputs = self.get_dummy_inputs(torch_device)
        prompt = 2 * [inputs.pop("prompt")]

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = sd_pipe.encode_prompt(prompt)

        output = sd_pipe(
            **inputs,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        )
        image_slice_2 = output.images[0, -3:, -3:, -1]

        # make sure that it's equal
        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4

    # copied from test_stable_diffusion_xl.py
    def test_save_load_optional_components(self):
        self._test_save_load_optional_components()

    # copied from test_controlnetxs.py
    def test_to_dtype(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        # pipeline creates a new UNetControlNetXSModel under the hood, so we need to check the dtype from pipe.components
        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

        pipe.to(dtype=torch.float16)
        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

    def test_multi_vae(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        block_out_channels = pipe.vae.config.block_out_channels
        norm_num_groups = pipe.vae.config.norm_num_groups

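        # Each VAE class gets a matching dummy config; note that the tiny autoencoder
        # config takes no norm_num_groups, as its architecture does not use group norm.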
        vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
        configs = [
            get_autoencoder_kl_config(block_out_channels, norm_num_groups),
            get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
            get_consistency_vae_config(block_out_channels, norm_num_groups),
            get_autoencoder_tiny_config(block_out_channels),
        ]

        out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]

        for vae_cls, config in zip(vae_classes, configs):
            vae = vae_cls(**config)
            vae = vae.to(torch_device)
            components["vae"] = vae
            vae_pipe = self.pipeline_class(**components)

            # The pipeline creates a new UNetControlNetXSModel under the hood, which isn't on device,
            # so we need to move the new pipe to device.
            vae_pipe.to(torch_device)
            vae_pipe.set_progress_bar_config(disable=None)

            out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]

            assert out_vae_np.shape == out_np.shape


@slow
@require_torch_gpu
class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_canny(self):
        controlnet = ControlNetXSAdapter.from_pretrained(
            "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
        )
        pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.enable_sequential_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "bird"
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
        )

        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

        assert images[0].shape == (768, 512, 3)

        original_image = images[0, -3:, -3:, -1].flatten()
        expected_image = np.array([0.3202, 0.3151, 0.3328, 0.3172, 0.337, 0.3381, 0.3378, 0.3389, 0.3224])
        assert np.allclose(original_image, expected_image, atol=1e-04)

    def test_depth(self):
        controlnet = ControlNetXSAdapter.from_pretrained(
            "UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16
        )
        pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.enable_sequential_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "Stormtrooper's lecture"
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
        )

        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

        assert images[0].shape == (512, 512, 3)

        original_image = images[0, -3:, -3:, -1].flatten()
        expected_image = np.array([0.5448, 0.5437, 0.5426, 0.5543, 0.553, 0.5475, 0.5595, 0.5602, 0.5529])
        assert np.allclose(original_image, expected_image, atol=1e-04)
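For orientation, the end-user flow these slow tests exercise looks roughly like the following. This is a minimal sketch, assuming ControlNetXSAdapter and StableDiffusionXLControlNetXSPipeline are exported from the top-level diffusers namespace as the tests use them; the checkpoint IDs and conditioning image are taken from test_canny, while the prompt and step count are illustrative (the tests use 3 steps only for speed):

import torch
from diffusers import ControlNetXSAdapter, StableDiffusionXLControlNetXSPipeline
from diffusers.utils import load_image

# Attach a ControlNet-XS adapter to the SDXL base weights.
controlnet = ControlNetXSAdapter.from_pretrained(
    "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()

# Condition generation on a canny edge map.
canny = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
image = pipe("bird", image=canny, num_inference_steps=50).images[0]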
The commit also updates the shared PipelineTesterMixin so that the common checks accept a UNetControlNetXSModel as a valid `pipe.unet`:

@@ -32,6 +32,7 @@ from diffusers import (
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import IPAdapterMixin
 from diffusers.models.attention_processor import AttnProcessor
+from diffusers.models.controlnet_xs import UNetControlNetXSModel
 from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
 from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from diffusers.models.unets.unet_motion_model import UNetMotionModel
@@ -1685,7 +1686,10 @@ class PipelineTesterMixin:
         self.assertTrue(hasattr(pipe, "vae") and isinstance(pipe.vae, (AutoencoderKL, AutoencoderTiny)))
         self.assertTrue(
             hasattr(pipe, "unet")
-            and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel))
+            and isinstance(
+                pipe.unet,
+                (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel, UNetControlNetXSModel),
+            )
         )