diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 06421305c3..0ce2e04ad9 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -20,7 +20,7 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin -from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, deprecate, scale_lora_layers, unscale_lora_layers +from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from .activations import get_activation from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -824,8 +824,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added to UNet long skip connections from down blocks to up blocks - for example from ControlNet side model(s) + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) mid_block_additional_residual (`torch.Tensor`, *optional*): additional residual to be added to UNet mid block output, for example from ControlNet side model down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): @@ -1014,12 +1014,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: - deprecate("T2I should not use down_block_additional_residuals", - "1.3.0", - "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", - standard_warn=False) + standard_warn=False, + ) down_intrablock_additional_residuals = down_block_additional_residuals is_adapter = True diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 2ed3deeb12..a70903b4bd 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -987,6 +987,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: @@ -1031,6 +1032,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: @@ -1216,15 +1224,31 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None - is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated " + " and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only" + " be used for ControlNet. Please make sure use" + " `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlockFlat additional_residuals = {} - if is_adapter and len(down_block_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, @@ -1237,9 +1261,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) - - if is_adapter and len(down_block_additional_residuals) > 0: - sample += down_block_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples @@ -1267,10 +1290,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): # To support T2I-Adapter-XL if ( is_adapter - and len(down_block_additional_residuals) > 0 - and sample.shape == down_block_additional_residuals[0].shape + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape ): - sample += down_block_additional_residuals.pop(0) + sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual