mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* Add support of Xlabs Controlnets --------- Co-authored-by: Anzhella Pankratova <son0shad@gmail.com>
This commit is contained in:
@@ -23,7 +23,7 @@ from ..loaders import PeftAdapterMixin
|
||||
from ..models.attention_processor import AttentionProcessor
|
||||
from ..models.modeling_utils import ModelMixin
|
||||
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .controlnet import BaseOutput, zero_module
|
||||
from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
|
||||
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from .modeling_outputs import Transformer2DModelOutput
|
||||
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
@@ -55,6 +55,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
@@ -106,7 +107,14 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
if self.union:
|
||||
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
|
||||
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
if conditioning_embedding_channels is not None:
|
||||
self.input_hint_block = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
|
||||
)
|
||||
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
else:
|
||||
self.input_hint_block = None
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@@ -269,6 +277,16 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
if self.input_hint_block is not None:
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
|
||||
height = height_pw // self.config.patch_size
|
||||
width = width_pw // self.config.patch_size
|
||||
controlnet_cond = controlnet_cond.reshape(
|
||||
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
|
||||
)
|
||||
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
|
||||
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
|
||||
# add
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
|
||||
|
||||
|
||||
@@ -402,6 +402,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
return_dict: bool = True,
|
||||
controlnet_blocks_repeat: bool = False,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
@@ -508,7 +509,13 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
# For Xlabs ControlNet.
|
||||
if controlnet_blocks_repeat:
|
||||
hidden_states = (
|
||||
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
|
||||
@@ -754,19 +754,22 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
# xlab controlnet has a input_hint_block and instantx controlnet does not
|
||||
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
|
||||
if self.controlnet.input_hint_block is None:
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
# Here we ensure that `control_mode` has the same length as the control_image.
|
||||
if control_mode is not None:
|
||||
@@ -777,8 +780,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
|
||||
elif isinstance(self.controlnet, FluxMultiControlNetModel):
|
||||
control_images = []
|
||||
|
||||
for control_image_ in control_image:
|
||||
# xlab controlnet has a input_hint_block and instantx controlnet does not
|
||||
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
|
||||
for i, control_image_ in enumerate(control_image):
|
||||
control_image_ = self.prepare_image(
|
||||
image=control_image_,
|
||||
width=width,
|
||||
@@ -790,20 +794,20 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
)
|
||||
height, width = control_image_.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
if self.controlnet.nets[0].input_hint_block is None:
|
||||
# vae encode
|
||||
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
control_images.append(control_image_)
|
||||
|
||||
control_image = control_images
|
||||
@@ -927,6 +931,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
controlnet_blocks_repeat=controlnet_blocks_repeat,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
||||
Reference in New Issue
Block a user