
Add ability to mix usage of T2I-Adapter(s) and ControlNet(s). (#5362)

* Add ability to mix usage of T2I-Adapter(s) and ControlNet(s).
Previously, the UNet2DConditionModel implementation only allowed use of one or the other.
Adds a new forward() arg, down_intrablock_additional_residuals, specifically for T2I-Adapters. If down_intrablock_additional_residuals is not used, backward compatibility is maintained with the prior usage of only a T2I-Adapter or only a ControlNet, but not both (a mixed-usage sketch follows below).
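A minimal sketch of the now-possible mixed call, assuming a loaded UNet2DConditionModel `unet` and precomputed ControlNet and T2I-Adapter feature tensors; all variable names here are illustrative, not part of this commit:

    noise_pred = unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
        # ControlNet residuals: added to the long skip connections and the mid block output
        down_block_additional_residuals=controlnet_down_block_res_samples,
        mid_block_additional_residual=controlnet_mid_block_res_sample,
        # T2I-Adapter residuals: added within the down blocks via the new arg
        down_intrablock_additional_residuals=[s.clone() for s in adapter_state],
    ).sample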

* Improving forward() arg docs in src/diffusers/models/unet_2d_condition.py

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>

* Add deprecation warning if down_block_additional_residuals is used for T2I-Adapter (intrablock residuals); the legacy vs. new call paths are sketched after the unet_2d_condition.py diff below

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Oops my bad, fixing last commit.

* Added import of diffusers utils.deprecate

* Conform to max line length

* Modify the T2I-Adapter pipelines to reflect the change to the UNet forward() arg for T2I-Adapter residuals (a usage sketch follows the SDXL pipeline diff at the end of this page)

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Gregg Helt
Date: 2023-10-16 07:29:05 -07:00
Committed by: GitHub
Parent: cc12f3ec92
Commit: de12776b3a
3 changed files with 34 additions and 14 deletions

src/diffusers/models/unet_2d_condition.py

@@ -20,7 +20,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, deprecate, scale_lora_layers, unscale_lora_layers
 from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -778,6 +778,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
+        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
@@ -822,6 +823,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             added_cond_kwargs: (`dict`, *optional*):
                 A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                 are passed along to the UNet blocks.
+            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added to UNet long skip connections from down blocks to up blocks,
+                for example from ControlNet side model(s)
+            mid_block_additional_residual (`torch.Tensor`, *optional*):
+                additional residual to be added to UNet mid block output, for example from ControlNet side model
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)

         Returns:
             [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -1000,15 +1008,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             scale_lora_layers(self, lora_scale)

         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
-        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+        is_adapter = down_intrablock_additional_residuals is not None
+        # maintain backward compatibility for legacy usage, where
+        # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+        # but can only use one or the other
+        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+            deprecate("T2I should not use down_block_additional_residuals",
+                "1.3.0",
+                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+                and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+                for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+                standard_warn=False)
+            down_intrablock_additional_residuals = down_block_additional_residuals
+            is_adapter = True

         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 # For t2i-adapter CrossAttnDownBlock2D
                 additional_residuals = {}
-                if is_adapter and len(down_block_additional_residuals) > 0:
-                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

                 sample, res_samples = downsample_block(
                     hidden_states=sample,
@@ -1021,9 +1042,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
-
-                if is_adapter and len(down_block_additional_residuals) > 0:
-                    sample += down_block_additional_residuals.pop(0)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    sample += down_intrablock_additional_residuals.pop(0)

             down_block_res_samples += res_samples
@@ -1051,10 +1071,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         # To support T2I-Adapter-XL
         if (
             is_adapter
-            and len(down_block_additional_residuals) > 0
-            and sample.shape == down_block_additional_residuals[0].shape
+            and len(down_intrablock_additional_residuals) > 0
+            and sample.shape == down_intrablock_additional_residuals[0].shape
         ):
-            sample += down_block_additional_residuals.pop(0)
+            sample += down_intrablock_additional_residuals.pop(0)

         if is_controlnet:
             sample = sample + mid_block_additional_residual
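
For reference, the two call paths through the deprecation shim above, as a minimal sketch assuming a loaded `unet` and a list of T2I-Adapter feature tensors `adapter_state` (illustrative names, not from this commit):

    # Legacy path: adapter residuals passed as down_block_additional_residuals.
    # The shim remaps them to down_intrablock_additional_residuals and warns.
    out = unet(sample, t, encoder_hidden_states=prompt_embeds,
               down_block_additional_residuals=[s.clone() for s in adapter_state]).sample

    # New path: the dedicated intrablock arg; no deprecation warning.
    out = unet(sample, t, encoder_hidden_states=prompt_embeds,
               down_intrablock_additional_residuals=[s.clone() for s in adapter_state]).sample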

src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py

@@ -813,7 +813,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    down_block_additional_residuals=[state.clone() for state in adapter_state],
+                    down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
                 ).sample

                 # perform guidance

src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py

@@ -975,9 +975,9 @@ class StableDiffusionXLAdapterPipeline(
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                 if i < int(num_inference_steps * adapter_conditioning_factor):
-                    down_block_additional_residuals = [state.clone() for state in adapter_state]
+                    down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                 else:
-                    down_block_additional_residuals = None
+                    down_intrablock_additional_residuals = None

                 noise_pred = self.unet(
                     latent_model_input,
@@ -986,7 +986,7 @@ class StableDiffusionXLAdapterPipeline(
                     cross_attention_kwargs=cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
-                    down_block_additional_residuals=down_block_additional_residuals,
+                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )[0]

                 # perform guidance
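
For completeness, a hedged usage sketch of the SDXL adapter pipeline above; `adapter_conditioning_factor` controls the fraction of denoising steps that receive the intrablock residuals. The model IDs and the preprocessed `canny_image` input are illustrative assumptions, not part of this commit:

    import torch
    from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter

    adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16)
    pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
    ).to("cuda")

    image = pipe(
        "a photo of a castle",
        image=canny_image,  # control image preprocessed to Canny edges (assumed prepared)
        adapter_conditioning_factor=0.8,  # adapter residuals applied for the first 80% of steps
    ).images[0]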