1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix controlnet module refactor (#9968)

* fix
This commit is contained in:
YiYi Xu
2024-11-20 13:11:39 -10:00
committed by GitHub
parent 3139d39fa7
commit e564abe292
22 changed files with 272 additions and 58 deletions

View File

@@ -107,6 +107,7 @@ else:
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
"MultiControlNetModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SD3ControlNetModel",
@@ -592,6 +593,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ModelMixin,
MotionAdapter,
MultiAdapter,
MultiControlNetModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3ControlNetModel,

View File

@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from ..utils import deprecate
from .controlnets.controlnet import ( # noqa
BaseOutput,
ControlNetConditioningEmbedding,
ControlNetModel,
ControlNetOutput,
@@ -24,19 +25,91 @@ from .controlnets.controlnet import ( # noqa
class ControlNetOutput(ControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
deprecate("ControlNetOutput", "0.34", deprecation_message)
deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class ControlNetModel(ControlNetModel):
def __init__(self, *args, **kwargs):
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
deprecate("ControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
super().__init__(
in_channels=in_channels,
conditioning_channels=conditioning_channels,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
down_block_types=down_block_types,
mid_block_type=mid_block_type,
only_cross_attention=only_cross_attention,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
downsample_padding=downsample_padding,
mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
cross_attention_dim=cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
use_linear_projection=use_linear_projection,
class_embed_type=class_embed_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
num_class_embeds=num_class_embeds,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
global_pool_conditions=global_pool_conditions,
addition_embed_type_num_heads=addition_embed_type_num_heads,
)
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
deprecate("ControlNetConditioningEmbedding", "0.34", deprecation_message)
deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
super().__init__(*args, **kwargs)

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from typing import List
from ..utils import deprecate, logging
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
@@ -23,19 +25,46 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxControlNetOutput(FluxControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
deprecate("FluxControlNetOutput", "0.34", deprecation_message)
deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class FluxControlNetModel(FluxControlNetModel):
def __init__(self, *args, **kwargs):
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
conditioning_embedding_channels: int = None,
):
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
deprecate("FluxControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
super().__init__(
patch_size=patch_size,
in_channels=in_channels,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
joint_attention_dim=joint_attention_dim,
pooled_projection_dim=pooled_projection_dim,
guidance_embeds=guidance_embeds,
axes_dims_rope=axes_dims_rope,
num_mode=num_mode,
conditioning_embedding_channels=conditioning_embedding_channels,
)
class FluxMultiControlNetModel(FluxMultiControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
deprecate("FluxMultiControlNetModel", "0.34", deprecation_message)
deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)

View File

@@ -23,19 +23,46 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SD3ControlNetOutput(SD3ControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
deprecate("SD3ControlNetOutput", "0.34", deprecation_message)
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class SD3ControlNetModel(SD3ControlNetModel):
def __init__(self, *args, **kwargs):
def __init__(
self,
sample_size: int = 128,
patch_size: int = 2,
in_channels: int = 16,
num_layers: int = 18,
attention_head_dim: int = 64,
num_attention_heads: int = 18,
joint_attention_dim: int = 4096,
caption_projection_dim: int = 1152,
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
extra_conditioning_channels: int = 0,
):
deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
deprecate("SD3ControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
super().__init__(
sample_size=sample_size,
patch_size=patch_size,
in_channels=in_channels,
num_layers=num_layers,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
joint_attention_dim=joint_attention_dim,
caption_projection_dim=caption_projection_dim,
pooled_projection_dim=pooled_projection_dim,
out_channels=out_channels,
pos_embed_max_size=pos_embed_max_size,
extra_conditioning_channels=extra_conditioning_channels,
)
class SD3MultiControlNetModel(SD3MultiControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
deprecate("SD3MultiControlNetModel", "0.34", deprecation_message)
deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Optional, Tuple, Union
from ..utils import deprecate, logging
from .controlnets.controlnet_sparsectrl import ( # noqa
SparseControlNetConditioningEmbedding,
@@ -28,19 +30,87 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SparseControlNetOutput(SparseControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
deprecate("SparseControlNetOutput", "0.34", deprecation_message)
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
deprecate("SparseControlNetConditioningEmbedding", "0.34", deprecation_message)
deprecate(
"diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
)
super().__init__(*args, **kwargs)
class SparseControlNetModel(SparseControlNetModel):
def __init__(self, *args, **kwargs):
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"DownBlockMotion",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 768,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
controlnet_conditioning_channel_order: str = "rgb",
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
concat_conditioning_mask: bool = True,
use_simplified_condition_embedding: bool = True,
):
deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
deprecate("SparseControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
super().__init__(
in_channels=in_channels,
conditioning_channels=conditioning_channels,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
down_block_types=down_block_types,
only_cross_attention=only_cross_attention,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
downsample_padding=downsample_padding,
mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
cross_attention_dim=cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
transformer_layers_per_mid_block=transformer_layers_per_mid_block,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
global_pool_conditions=global_pool_conditions,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
motion_max_seq_length=motion_max_seq_length,
motion_num_attention_heads=motion_num_attention_heads,
concat_conditioning_mask=concat_conditioning_mask,
use_simplified_condition_embedding=use_simplified_condition_embedding,
)

View File

@@ -22,8 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.attention_processor import AttentionProcessor
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
@@ -192,13 +192,13 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
num_attention_heads: int = 24,
load_weights_from_transformer=True,
):
config = transformer.config
config = dict(transformer.config)
config["num_layers"] = num_layers
config["num_single_layers"] = num_single_layers
config["attention_head_dim"] = attention_head_dim
config["num_attention_heads"] = num_attention_heads
controlnet = cls(**config)
controlnet = cls.from_config(config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())

View File

@@ -18,7 +18,7 @@ import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils import BaseOutput, logging
from ..attention_processor import AttentionProcessor
from ..embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
@@ -27,7 +27,7 @@ from ..embeddings import (
)
from ..modeling_utils import ModelMixin
from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
from .controlnet import BaseOutput, Tuple, zero_module
from .controlnet import Tuple, zero_module
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -82,7 +82,7 @@ class MultiControlNetModel(ModelMixin):
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
`[`~models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained`]` class method.
Arguments:
save_directory (`str` or `os.PathLike`):
@@ -128,7 +128,7 @@ class MultiControlNetModel(ModelMixin):
Parameters:
pretrained_model_path (`os.PathLike`):
A path to a *directory* containing model weights saved using
[`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
[`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
`./my_model_directory/controlnet`.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype

View File

@@ -21,14 +21,20 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models import (
AutoencoderKL,
ControlNetModel,
ImageProjection,
MultiControlNetModel,
UNet2DConditionModel,
UNetMotionModel,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..controlnet.multicontrolnet import MultiControlNetModel
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin

View File

@@ -21,7 +21,14 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models import (
AutoencoderKL,
ControlNetModel,
ImageProjection,
MultiControlNetModel,
UNet2DConditionModel,
UNetMotionModel,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import (
@@ -35,7 +42,6 @@ from ...schedulers import (
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..controlnet.multicontrolnet import MultiControlNetModel
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin

View File

@@ -8,5 +8,5 @@ logger = logging.get_logger(__name__)
class MultiControlNetModel(MultiControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead."
deprecate("MultiControlNetModel", "0.34", deprecation_message)
deprecate("diffusers.pipelines.controlnet.multicontrolnet.MultiControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)

View File

@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -40,7 +40,6 @@ from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_ten
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .multicontrolnet import MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -39,7 +39,6 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .multicontrolnet import MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -26,7 +26,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -41,7 +41,6 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .multicontrolnet import MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -35,7 +35,7 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -54,7 +54,6 @@ from ...utils import (
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from .multicontrolnet import MultiControlNetModel
if is_invisible_watermark_available():

View File

@@ -38,7 +38,7 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -61,8 +61,6 @@ from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutpu
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
from .multicontrolnet import MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -38,7 +38,7 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -61,8 +61,6 @@ from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutpu
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
from .multicontrolnet import MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -36,7 +36,6 @@ from ...utils import (
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..controlnet.multicontrolnet import MultiControlNetModel
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker

View File

@@ -26,7 +26,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -37,7 +37,6 @@ from ...utils import (
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..controlnet.multicontrolnet import MultiControlNetModel
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker

View File

@@ -38,7 +38,7 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -61,8 +61,6 @@ from .pag_utils import PAGMixin
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
from ..controlnet.multicontrolnet import MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -38,7 +38,7 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -61,8 +61,6 @@ from .pag_utils import PAGMixin
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
from ..controlnet.multicontrolnet import MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -452,6 +452,21 @@ class MultiAdapter(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class MultiControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]