mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Tencent Hunyuan Team] Add Hunyuan-DiT ControlNet Inference (#8694)
* add controlnet support --------- Co-authored-by: xingchaoliu <xingchaoliu@tencent.com> Co-authored-by: yiyixuxu <yixu310@gmail,com>
This commit is contained in:
@@ -257,6 +257,8 @@
|
||||
title: PriorTransformer
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
- local: api/models/controlnet_hunyuandit
|
||||
title: HunyuanDiT2DControlNetModel
|
||||
- local: api/models/controlnet_sd3
|
||||
title: SD3ControlNetModel
|
||||
title: Models
|
||||
@@ -282,6 +284,8 @@
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
title: ControlNet
|
||||
- local: api/pipelines/controlnet_hunyuandit
|
||||
title: ControlNet with Hunyuan-DiT
|
||||
- local: api/pipelines/controlnet_sd3
|
||||
title: ControlNet with Stable Diffusion 3
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
|
||||
37
docs/source/en/api/models/controlnet_hunyuandit.md
Normal file
37
docs/source/en/api/models/controlnet_hunyuandit.md
Normal file
@@ -0,0 +1,37 @@
|
||||
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# HunyuanDiT2DControlNetModel
|
||||
|
||||
HunyuanDiT2DControlNetModel is an implementation of ControlNet for [Hunyuan-DiT](https://arxiv.org/abs/2405.08748).
|
||||
|
||||
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
|
||||
|
||||
With a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||
|
||||
This code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).
|
||||
|
||||
## Example For Loading HunyuanDiT2DControlNetModel
|
||||
|
||||
```py
|
||||
from diffusers import HunyuanDiT2DControlNetModel
|
||||
import torch
|
||||
controlnet = HunyuanDiT2DControlNetModel.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## HunyuanDiT2DControlNetModel
|
||||
|
||||
[[autodoc]] HunyuanDiT2DControlNetModel
|
||||
36
docs/source/en/api/pipelines/controlnet_hunyuandit.md
Normal file
36
docs/source/en/api/pipelines/controlnet_hunyuandit.md
Normal file
@@ -0,0 +1,36 @@
|
||||
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNet with Hunyuan-DiT
|
||||
|
||||
HunyuanDiTControlNetPipeline is an implementation of ControlNet for [Hunyuan-DiT](https://arxiv.org/abs/2405.08748).
|
||||
|
||||
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
|
||||
|
||||
With a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||
|
||||
This code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## HunyuanDiTControlNetPipeline
|
||||
[[autodoc]] HunyuanDiTControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -1,4 +1,4 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
@@ -83,7 +83,9 @@ else:
|
||||
"ControlNetModel",
|
||||
"ControlNetXSAdapter",
|
||||
"DiTTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"I2VGenXLUNet",
|
||||
"Kandinsky3UNet",
|
||||
"ModelMixin",
|
||||
@@ -234,6 +236,7 @@ else:
|
||||
"BlipDiffusionPipeline",
|
||||
"CLIPImageProjection",
|
||||
"CycleDiffusionPipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
"I2VGenXLPipeline",
|
||||
"IFImg2ImgPipeline",
|
||||
@@ -500,7 +503,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetModel,
|
||||
ControlNetXSAdapter,
|
||||
DiTTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
ModelMixin,
|
||||
@@ -629,6 +634,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDMPipeline,
|
||||
CLIPImageProjection,
|
||||
CycleDiffusionPipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
I2VGenXLPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
|
||||
@@ -33,6 +33,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
|
||||
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
@@ -75,6 +76,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VQModel,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
|
||||
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
|
||||
from .embeddings import ImageProjection
|
||||
|
||||
399
src/diffusers/models/controlnet_hunyuan.py
Normal file
399
src/diffusers/models/controlnet_hunyuan.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .attention_processor import AttentionProcessor
|
||||
from .controlnet import BaseOutput, Tuple, zero_module
|
||||
from .embeddings import (
|
||||
HunyuanCombinedTimestepTextSizeStyleEmbedding,
|
||||
PatchEmbed,
|
||||
PixArtAlphaTextProjection,
|
||||
)
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class HunyuanControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_channels: int = 3,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
sample_size=32,
|
||||
hidden_size=1152,
|
||||
transformer_num_layers: int = 40,
|
||||
mlp_ratio: float = 4.0,
|
||||
cross_attention_dim: int = 1024,
|
||||
cross_attention_dim_t5: int = 2048,
|
||||
pooled_projection_dim: int = 1024,
|
||||
text_len: int = 77,
|
||||
text_len_t5: int = 256,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.text_embedder = PixArtAlphaTextProjection(
|
||||
in_features=cross_attention_dim_t5,
|
||||
hidden_size=cross_attention_dim_t5 * 4,
|
||||
out_features=cross_attention_dim,
|
||||
act_fn="silu_fp32",
|
||||
)
|
||||
|
||||
self.text_embedding_padding = nn.Parameter(
|
||||
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
|
||||
)
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=hidden_size,
|
||||
patch_size=patch_size,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
|
||||
self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
|
||||
hidden_size,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
seq_len=text_len_t5,
|
||||
cross_attention_dim=cross_attention_dim_t5,
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
|
||||
# HunyuanDiT Blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanDiTBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
activation_fn=activation_fn,
|
||||
ff_inner_dim=int(self.inner_dim * mlp_ratio),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
|
||||
skip=False, # always False as it is the first half of the model
|
||||
)
|
||||
for layer in range(transformer_num_layers // 2 - 1)
|
||||
]
|
||||
)
|
||||
self.input_block = zero_module(nn.Linear(hidden_size, hidden_size))
|
||||
for _ in range(len(self.blocks)):
|
||||
controlnet_block = nn.Linear(hidden_size, hidden_size)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the
|
||||
corresponding cross attention processor. This is strongly recommended when setting trainable attention
|
||||
processors.
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls, transformer, conditioning_channels=3, transformer_num_layers=None, load_weights_from_transformer=True
|
||||
):
|
||||
config = transformer.config
|
||||
activation_fn = config.activation_fn
|
||||
attention_head_dim = config.attention_head_dim
|
||||
cross_attention_dim = config.cross_attention_dim
|
||||
cross_attention_dim_t5 = config.cross_attention_dim_t5
|
||||
hidden_size = config.hidden_size
|
||||
in_channels = config.in_channels
|
||||
mlp_ratio = config.mlp_ratio
|
||||
num_attention_heads = config.num_attention_heads
|
||||
patch_size = config.patch_size
|
||||
sample_size = config.sample_size
|
||||
text_len = config.text_len
|
||||
text_len_t5 = config.text_len_t5
|
||||
|
||||
conditioning_channels = conditioning_channels
|
||||
transformer_num_layers = transformer_num_layers or config.transformer_num_layers
|
||||
|
||||
controlnet = cls(
|
||||
conditioning_channels=conditioning_channels,
|
||||
transformer_num_layers=transformer_num_layers,
|
||||
activation_fn=activation_fn,
|
||||
attention_head_dim=attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
cross_attention_dim_t5=cross_attention_dim_t5,
|
||||
hidden_size=hidden_size,
|
||||
in_channels=in_channels,
|
||||
mlp_ratio=mlp_ratio,
|
||||
num_attention_heads=num_attention_heads,
|
||||
patch_size=patch_size,
|
||||
sample_size=sample_size,
|
||||
text_len=text_len,
|
||||
text_len_t5=text_len_t5,
|
||||
)
|
||||
if load_weights_from_transformer:
|
||||
key = controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
||||
logger.warning(f"controlnet load from Hunyuan-DiT. missing_keys: {key[0]}")
|
||||
return controlnet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states=None,
|
||||
text_embedding_mask=None,
|
||||
encoder_hidden_states_t5=None,
|
||||
text_embedding_mask_t5=None,
|
||||
image_meta_size=None,
|
||||
style=None,
|
||||
image_rotary_emb=None,
|
||||
return_dict=True,
|
||||
):
|
||||
"""
|
||||
The [`HunyuanDiT2DControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
|
||||
The input tensor.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step.
|
||||
controlnet_cond ( `torch.Tensor` ):
|
||||
The conditioning input to ControlNet.
|
||||
conditioning_scale ( `float` ):
|
||||
Indicate the conditioning scale.
|
||||
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
|
||||
text_embedding_mask: torch.Tensor
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
|
||||
of `BertModel`.
|
||||
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
|
||||
text_embedding_mask_t5: torch.Tensor
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
|
||||
of T5 Text Encoder.
|
||||
image_meta_size (torch.Tensor):
|
||||
Conditional embedding indicate the image sizes
|
||||
style: torch.Tensor:
|
||||
Conditional embedding indicate the style
|
||||
image_rotary_emb (`torch.Tensor`):
|
||||
The image rotary embeddings to apply on query and key tensors during attention calculation.
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # b,c,H,W -> b, N, C
|
||||
|
||||
# 2. pre-process
|
||||
hidden_states = hidden_states + self.input_block(self.pos_embed(controlnet_cond))
|
||||
|
||||
temb = self.time_extra_emb(
|
||||
timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
|
||||
) # [B, D]
|
||||
|
||||
# text projection
|
||||
batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
|
||||
encoder_hidden_states_t5 = self.text_embedder(
|
||||
encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
|
||||
)
|
||||
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
|
||||
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
|
||||
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
|
||||
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
|
||||
|
||||
encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
|
||||
|
||||
block_res_samples = ()
|
||||
for layer, block in enumerate(self.blocks):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
temb=temb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
) # (N, L, D)
|
||||
|
||||
block_res_samples = block_res_samples + (hidden_states,)
|
||||
|
||||
controlnet_block_res_samples = ()
|
||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||
|
||||
# 6. scaling
|
||||
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_res_samples,)
|
||||
|
||||
return HunyuanControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
|
||||
|
||||
|
||||
class HunyuanDiT2DMultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
`HunyuanDiT2DMultiControlNetModel` wrapper class for Multi-HunyuanDiT2DControlNetModel
|
||||
|
||||
This module is a wrapper for multiple instances of the `HunyuanDiT2DControlNetModel`. The `forward()` API is
|
||||
designed to be compatible with `HunyuanDiT2DControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[HunyuanDiT2DControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`HunyuanDiT2DControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states=None,
|
||||
text_embedding_mask=None,
|
||||
encoder_hidden_states_t5=None,
|
||||
text_embedding_mask_t5=None,
|
||||
image_meta_size=None,
|
||||
style=None,
|
||||
image_rotary_emb=None,
|
||||
return_dict=True,
|
||||
):
|
||||
"""
|
||||
The [`HunyuanDiT2DControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
|
||||
The input tensor.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step.
|
||||
controlnet_cond ( `torch.Tensor` ):
|
||||
The conditioning input to ControlNet.
|
||||
conditioning_scale ( `float` ):
|
||||
Indicate the conditioning scale.
|
||||
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
|
||||
text_embedding_mask: torch.Tensor
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
|
||||
of `BertModel`.
|
||||
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
|
||||
text_embedding_mask_t5: torch.Tensor
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
|
||||
of T5 Text Encoder.
|
||||
image_meta_size (torch.Tensor):
|
||||
Conditional embedding indicate the image sizes
|
||||
style: torch.Tensor:
|
||||
Conditional embedding indicate the style
|
||||
image_rotary_emb (`torch.Tensor`):
|
||||
The image rotary embeddings to apply on query and key tensors during attention calculation.
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
timestep=timestep,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=scale,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embedding_mask=text_embedding_mask,
|
||||
encoder_hidden_states_t5=encoder_hidden_states_t5,
|
||||
text_embedding_mask_t5=text_embedding_mask_t5,
|
||||
image_meta_size=image_meta_size,
|
||||
style=style,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
|
||||
]
|
||||
control_block_samples = (control_block_samples,)
|
||||
|
||||
return control_block_samples
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -437,6 +437,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
image_meta_size=None,
|
||||
style=None,
|
||||
image_rotary_emb=None,
|
||||
controlnet_block_samples=None,
|
||||
return_dict=True,
|
||||
):
|
||||
"""
|
||||
@@ -491,7 +492,10 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
skips = []
|
||||
for layer, block in enumerate(self.blocks):
|
||||
if layer > self.config.num_layers // 2:
|
||||
skip = skips.pop()
|
||||
if controlnet_block_samples is not None:
|
||||
skip = skips.pop() + controlnet_block_samples.pop()
|
||||
else:
|
||||
skip = skips.pop()
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
temb=temb,
|
||||
@@ -510,6 +514,9 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
if layer < (self.config.num_layers // 2 - 1):
|
||||
skips.append(hidden_states)
|
||||
|
||||
if controlnet_block_samples is not None and len(controlnet_block_samples) != 0:
|
||||
raise ValueError("The number of controls is not equal to the number of skip connections.")
|
||||
|
||||
# final layer
|
||||
hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
@@ -20,6 +20,7 @@ from ..utils import (
|
||||
_dummy_objects = {}
|
||||
_import_structure = {
|
||||
"controlnet": [],
|
||||
"controlnet_hunyuandit": [],
|
||||
"controlnet_sd3": [],
|
||||
"controlnet_xs": [],
|
||||
"deprecated": [],
|
||||
@@ -152,6 +153,11 @@ else:
|
||||
"StableDiffusionXLControlNetXSPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["controlnet_hunyuandit"].extend(
|
||||
[
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["controlnet_sd3"].extend(
|
||||
[
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
@@ -409,6 +415,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .controlnet_hunyuandit import (
|
||||
HunyuanDiTControlNetPipeline,
|
||||
)
|
||||
from .controlnet_sd3 import (
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
)
|
||||
|
||||
48
src/diffusers/pipelines/controlnet_hunyuandit/__init__.py
Normal file
48
src/diffusers/pipelines/controlnet_hunyuandit/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_hunyuandit_controlnet"] = ["HunyuanDiTControlNetPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_hunyuandit_controlnet import HunyuanDiTControlNetPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -122,6 +122,21 @@ class DiTTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanDiT2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -137,6 +152,21 @@ class HunyuanDiT2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class I2VGenXLUNet(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -212,6 +212,21 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HunyuanDiTPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
0
tests/pipelines/controlnet_hunyuandit/__init__.py
Normal file
0
tests/pipelines/controlnet_hunyuandit/__init__.py
Normal file
@@ -0,0 +1,350 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc and Tencent Hunyuan Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, BertModel, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
)
|
||||
from diffusers.models import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = HunyuanDiTControlNetPipeline
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"guidance_scale",
|
||||
"negative_prompt",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = HunyuanDiT2DModel(
|
||||
sample_size=16,
|
||||
num_layers=4,
|
||||
patch_size=2,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=3,
|
||||
in_channels=4,
|
||||
cross_attention_dim=32,
|
||||
cross_attention_dim_t5=32,
|
||||
pooled_projection_dim=16,
|
||||
hidden_size=24,
|
||||
activation_fn="gelu-approximate",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
controlnet = HunyuanDiT2DControlNetModel(
|
||||
sample_size=16,
|
||||
transformer_num_layers=4,
|
||||
patch_size=2,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=3,
|
||||
in_channels=4,
|
||||
cross_attention_dim=32,
|
||||
cross_attention_dim_t5=32,
|
||||
pooled_projection_dim=16,
|
||||
hidden_size=24,
|
||||
activation_fn="gelu-approximate",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL()
|
||||
|
||||
scheduler = DDPMScheduler()
|
||||
text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
"transformer": transformer.eval(),
|
||||
"vae": vae.eval(),
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"controlnet": controlnet,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
control_image = randn_tensor(
|
||||
(1, 3, 16, 16),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
controlnet_conditioning_scale = 0.5
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
"control_image": control_image,
|
||||
"controlnet_conditioning_scale": controlnet_conditioning_scale,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_controlnet_hunyuandit(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = HunyuanDiTControlNetPipeline(**components)
|
||||
pipe = pipe.to(torch_device, dtype=torch.float16)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)
|
||||
image = output.images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 16, 16, 3)
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094]
|
||||
)
|
||||
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(
|
||||
expected_max_diff=1e-3,
|
||||
)
|
||||
|
||||
def test_sequential_cpu_offload_forward_pass(self):
|
||||
# TODO(YiYi) need to fix later
|
||||
pass
|
||||
|
||||
def test_sequential_offload_forward_pass_twice(self):
|
||||
# TODO(YiYi) need to fix later
|
||||
pass
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
# TODO(YiYi) need to fix later
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class HunyuanDiTControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = HunyuanDiTControlNetPipeline
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_canny(self):
|
||||
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere."
|
||||
n_prompt = ""
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true"
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
prompt,
|
||||
negative_prompt=n_prompt,
|
||||
control_image=control_image,
|
||||
controlnet_conditioning_scale=0.5,
|
||||
guidance_scale=5.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
generator=generator,
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (1024, 1024, 3)
|
||||
|
||||
original_image = image[-3:, -3:, -1].flatten()
|
||||
|
||||
expected_image = np.array(
|
||||
[0.43652344, 0.4399414, 0.44921875, 0.45043945, 0.45703125, 0.44873047, 0.43579102, 0.44018555, 0.42578125]
|
||||
)
|
||||
|
||||
assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
|
||||
|
||||
def test_pose(self):
|
||||
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "An Asian woman, dressed in a green top, wearing a purple headscarf and a purple scarf, stands in front of a blackboard. The background is the blackboard. The photo is presented in a close-up, eye-level, and centered composition, adopting a realistic photographic style"
|
||||
n_prompt = ""
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose/resolve/main/pose.jpg?download=true"
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
prompt,
|
||||
negative_prompt=n_prompt,
|
||||
control_image=control_image,
|
||||
controlnet_conditioning_scale=0.5,
|
||||
guidance_scale=5.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
generator=generator,
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (1024, 1024, 3)
|
||||
|
||||
original_image = image[-3:, -3:, -1].flatten()
|
||||
|
||||
expected_image = np.array(
|
||||
[0.4091797, 0.4177246, 0.39526367, 0.4194336, 0.40356445, 0.3857422, 0.39208984, 0.40429688, 0.37451172]
|
||||
)
|
||||
|
||||
assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
|
||||
|
||||
def test_depth(self):
|
||||
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Depth", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "In the dense forest, a black and white panda sits quietly in green trees and red flowers, surrounded by mountains, rivers, and the ocean. The background is the forest in a bright environment."
|
||||
n_prompt = ""
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Depth/resolve/main/depth.jpg?download=true"
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
prompt,
|
||||
negative_prompt=n_prompt,
|
||||
control_image=control_image,
|
||||
controlnet_conditioning_scale=0.5,
|
||||
guidance_scale=5.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
generator=generator,
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (1024, 1024, 3)
|
||||
|
||||
original_image = image[-3:, -3:, -1].flatten()
|
||||
|
||||
expected_image = np.array(
|
||||
[0.31982422, 0.32177734, 0.30126953, 0.3190918, 0.3100586, 0.31396484, 0.3232422, 0.33544922, 0.30810547]
|
||||
)
|
||||
|
||||
assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
|
||||
|
||||
def test_multi_controlnet(self):
|
||||
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16
|
||||
)
|
||||
controlnet = HunyuanDiT2DMultiControlNetModel([controlnet, controlnet])
|
||||
|
||||
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere."
|
||||
n_prompt = ""
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true"
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
prompt,
|
||||
negative_prompt=n_prompt,
|
||||
control_image=[control_image, control_image],
|
||||
controlnet_conditioning_scale=[0.25, 0.25],
|
||||
guidance_scale=5.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
generator=generator,
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (1024, 1024, 3)
|
||||
|
||||
original_image = image[-3:, -3:, -1].flatten()
|
||||
expected_image = np.array(
|
||||
[0.43652344, 0.44018555, 0.4494629, 0.44995117, 0.45654297, 0.44848633, 0.43603516, 0.4404297, 0.42626953]
|
||||
)
|
||||
|
||||
assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
|
||||
Reference in New Issue
Block a user