Add ability to mix usage of T2I-Adapter(s) and ControlNet(s). (#5362)
* Add ability to mix usage of T2I-Adapter(s) and ControlNet(s). Previously, the UNet2DConditionModel implementation only allowed use of one or the other. Adds a new forward() arg, down_intrablock_additional_residuals, specifically for T2I-Adapters. If down_intrablock_additional_residuals is not used, backward compatibility is maintained with the prior usage of only a T2I-Adapter or a ControlNet, but not both.

* Improve forward() arg docs in src/diffusers/models/unet_2d_condition.py

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>

* Add deprecation warning if down_block_additional_residuals is used for T2I-Adapter (intrablock residuals)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Oops, my bad, fixing the last commit.

* Added import of diffusers utils.deprecate

* Conform to max line length

* Modify T2I-Adapter pipelines to reflect the change to the UNet forward() arg for T2I-Adapter residuals.

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
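For illustration (not part of the commit), a minimal sketch of what mixing the two now looks like at the UNet level. The model IDs and dummy tensor shapes are assumptions chosen for SD 1.5; only the three residual kwargs come from this PR:

import torch
from diffusers import ControlNetModel, T2IAdapter, UNet2DConditionModel

# Assumed model IDs, chosen for illustration only.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd15v2")

latents = torch.randn(1, 4, 64, 64)          # dummy latents for a 512x512 image
t = torch.tensor([10])                        # dummy timestep
prompt_embeds = torch.randn(1, 77, 768)       # dummy text-encoder output
control_image = torch.randn(1, 3, 512, 512)   # stand-in for a real canny map
adapter_image = torch.randn(1, 3, 512, 512)   # stand-in for a real depth map

with torch.no_grad():
    # The ControlNet produces long-skip and mid-block residuals...
    cn_down, cn_mid = controlnet(
        latents, t, encoder_hidden_states=prompt_embeds,
        controlnet_cond=control_image, return_dict=False,
    )
    # ...while the T2I-Adapter produces intrablock features from the image alone.
    adapter_state = adapter(adapter_image)

    # Both kinds of residuals can now be passed in one forward() call.
    noise_pred = unet(
        latents, t, encoder_hidden_states=prompt_embeds,
        down_block_additional_residuals=cn_down,   # ControlNet path (unchanged)
        mid_block_additional_residual=cn_mid,      # ControlNet path (unchanged)
        down_intrablock_additional_residuals=[s.clone() for s in adapter_state],  # new arg
    ).sample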
src/diffusers/models/unet_2d_condition.py
@@ -20,7 +20,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, deprecate, scale_lora_layers, unscale_lora_layers
 from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -778,6 +778,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
+        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
@@ -822,6 +823,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             added_cond_kwargs: (`dict`, *optional*):
                 A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                 are passed along to the UNet blocks.
+            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added to UNet long skip connections from down blocks to up blocks,
+                for example from ControlNet side model(s)
+            mid_block_additional_residual (`torch.Tensor`, *optional*):
+                additional residual to be added to UNet mid block output, for example from ControlNet side model
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
 
         Returns:
             [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -1000,15 +1008,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             scale_lora_layers(self, lora_scale)
 
         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
-        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+        is_adapter = down_intrablock_additional_residuals is not None
+        # maintain backward compatibility for legacy usage, where
+        #     T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+        #     but can only use one or the other
+        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+            deprecate("T2I should not use down_block_additional_residuals",
+                      "1.3.0",
+                      "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+                      and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+                      for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+                      standard_warn=False)
+            down_intrablock_additional_residuals = down_block_additional_residuals
+            is_adapter = True
 
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 # For t2i-adapter CrossAttnDownBlock2D
                 additional_residuals = {}
-                if is_adapter and len(down_block_additional_residuals) > 0:
-                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
 
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
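To make the mode detection above concrete, the following standalone sketch replicates it outside the UNet (assuming a diffusers version earlier than 1.3.0, where deprecate warns rather than raises). The helper name route_residuals is hypothetical, written only to illustrate the backward-compatibility branch; deprecate is the real diffusers utility this commit imports:

from diffusers.utils import deprecate

def route_residuals(down_block_additional_residuals=None,
                    mid_block_additional_residual=None,
                    down_intrablock_additional_residuals=None):
    # Hypothetical helper mirroring the detection logic in forward().
    is_controlnet = (mid_block_additional_residual is not None
                     and down_block_additional_residuals is not None)
    is_adapter = down_intrablock_additional_residuals is not None
    if (not is_adapter and mid_block_additional_residual is None
            and down_block_additional_residuals is not None):
        # Legacy T2I-Adapter call: reroute to the new arg and warn.
        deprecate("T2I should not use down_block_additional_residuals", "1.3.0",
                  "`down_block_additional_residuals` should only be used for ControlNet.",
                  standard_warn=False)
        down_intrablock_additional_residuals = down_block_additional_residuals
        is_adapter = True
    return is_controlnet, is_adapter, down_intrablock_additional_residuals

# ControlNet only: (True, False, None)
print(route_residuals(down_block_additional_residuals=["d"], mid_block_additional_residual="m"))
# New-style T2I-Adapter only: (False, True, ["a"])
print(route_residuals(down_intrablock_additional_residuals=["a"]))
# Legacy T2I-Adapter call: deprecation warning, then (False, True, ["a"])
print(route_residuals(down_block_additional_residuals=["a"]))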
@@ -1021,9 +1042,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
-
-                if is_adapter and len(down_block_additional_residuals) > 0:
-                    sample += down_block_additional_residuals.pop(0)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    sample += down_intrablock_additional_residuals.pop(0)
 
             down_block_res_samples += res_samples
@@ -1051,10 +1071,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         # To support T2I-Adapter-XL
         if (
             is_adapter
-            and len(down_block_additional_residuals) > 0
-            and sample.shape == down_block_additional_residuals[0].shape
+            and len(down_intrablock_additional_residuals) > 0
+            and sample.shape == down_intrablock_additional_residuals[0].shape
         ):
-            sample += down_block_additional_residuals.pop(0)
+            sample += down_intrablock_additional_residuals.pop(0)
 
         if is_controlnet:
             sample = sample + mid_block_additional_residual
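The shape check above exists because (as the "To support T2I-Adapter-XL" comment indicates) the XL adapter can emit one more feature map than the down blocks consume; the leftover residual is applied at the mid block when its shape matches the hidden states. A toy illustration with fabricated shapes, not taken from the commit:

import torch

# Fabricated shapes: one leftover adapter residual after the down blocks.
sample = torch.randn(1, 1280, 32, 32)  # mid-block hidden states (illustrative)
down_intrablock_additional_residuals = [torch.randn(1, 1280, 32, 32)]  # leftover

if (len(down_intrablock_additional_residuals) > 0
        and sample.shape == down_intrablock_additional_residuals[0].shape):
    sample += down_intrablock_additional_residuals.pop(0)

assert len(down_intrablock_additional_residuals) == 0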
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -813,7 +813,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    down_block_additional_residuals=[state.clone() for state in adapter_state],
+                    down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
                 ).sample
 
                 # perform guidance
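For pipeline users nothing changes; the kwarg rename is internal to the pipeline. A minimal usage sketch, where the model IDs and image URL are assumptions rather than anything from the commit:

import torch
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v2", torch_dtype=torch.float16)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", adapter=adapter, torch_dtype=torch.float16
).to("cuda")

image = load_image("https://example.com/canny.png")  # placeholder control image
result = pipe("a photo of a house", image=image, num_inference_steps=30).images[0]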
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -975,9 +975,9 @@ class StableDiffusionXLAdapterPipeline(
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
                 if i < int(num_inference_steps * adapter_conditioning_factor):
-                    down_block_additional_residuals = [state.clone() for state in adapter_state]
+                    down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                 else:
-                    down_block_additional_residuals = None
+                    down_intrablock_additional_residuals = None
 
                 noise_pred = self.unet(
                     latent_model_input,
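The gating above means adapter residuals are injected only for the first fraction of denoising steps set by adapter_conditioning_factor; the numbers below are illustrative:

num_inference_steps = 50
adapter_conditioning_factor = 0.8  # illustrative value

steps_with_adapter = [i for i in range(num_inference_steps)
                      if i < int(num_inference_steps * adapter_conditioning_factor)]
assert steps_with_adapter == list(range(40))  # steps 0-39 receive residuals; 40-49 do not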
@@ -986,7 +986,7 @@ class StableDiffusionXLAdapterPipeline(
                     cross_attention_kwargs=cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
-                    down_block_additional_residuals=down_block_additional_residuals,
+                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )[0]
 
                 # perform guidance