diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index be9203b4d6..729bce548e 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -22,7 +22,7 @@
 from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import BaseOutput
+from ..utils import CONFIG_NAME, BaseOutput
 from ..utils.import_utils import is_xformers_available
@@ -666,3 +666,120 @@ class AdaLayerNorm(nn.Module):
         scale, shift = torch.chunk(emb, 2)
         x = self.norm(x) * (1 + scale) + shift
         return x
+
+
+class DualTransformer2DModel(nn.Module):
+    """
+    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            Pass if the input is continuous. The number of channels in the input and output.
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
+        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+            `ImagePositionalEmbeddings`.
+        num_vector_embeds (`int`, *optional*):
+            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+            Includes the class for the masked latent pixel.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+            up to but not more than `num_embeds_ada_norm` steps.
+        attention_bias (`bool`, *optional*):
+            Configure if the TransformerBlocks' attention should contain a bias parameter.
+ """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, num_condition_tokens[0]+num_condition_tokens[1], num_features)` + self.num_condition_tokens = (77, 257) + + def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. 
+ """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.num_condition_tokens[i]] + encoded_state = self.transformers[i](input_states, condition_state, timestep, return_dict)[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.num_condition_tokens[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) + + def _set_attention_slice(self, slice_size): + for transformer in self.transformers: + transformer._set_attention_slice(slice_size) + + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for transformer in self.transformers: + transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index c645f9f607..0988cbb0ab 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,7 +15,7 @@ import numpy as np import torch from torch import nn -from .attention import AttentionBlock, Transformer2DModel +from .attention import AttentionBlock, Transformer2DModel, DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index a6e12dfc17..93ac157b2c 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -17,7 +17,6 @@ from typing import Callable, List, Optional, Union import numpy as np import torch -import torch.nn as nn import torch.utils.checkpoint import PIL @@ -29,7 +28,7 @@ from transformers import ( ) from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention import Transformer2DModel, Transformer2DModelOutput +from ...models.attention import DualTransformer2DModel, Transformer2DModel from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import is_accelerate_available, logging @@ -94,12 +93,32 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): if isinstance(module, Transformer2DModel): parent_name, index = name.rsplit(".", 1) index = int(index) + image_transformer = self.image_unet.get_submodule(parent_name)[index] text_transformer = self.text_unet.get_submodule(parent_name)[index] + config = image_transformer.config dual_transformer = DualTransformer2DModel( - image_transformer, text_transformer, mix_ratio=mix_ratio, condition_types=condition_types + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + num_layers=config.num_layers, + dropout=config.dropout, + norm_num_groups=config.norm_num_groups, + cross_attention_dim=config.cross_attention_dim, + attention_bias=config.attention_bias, + sample_size=config.sample_size, + 
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index c645f9f607..0988cbb0ab 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -15,7 +15,7 @@
 import numpy as np
 import torch
 from torch import nn
 
-from .attention import AttentionBlock, Transformer2DModel
+from .attention import AttentionBlock, Transformer2DModel, DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
index a6e12dfc17..93ac157b2c 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -17,7 +17,6 @@
 from typing import Callable, List, Optional, Union
 
 import numpy as np
 import torch
-import torch.nn as nn
 import torch.utils.checkpoint
 
 import PIL
@@ -29,7 +28,7 @@ from transformers import (
 )
 
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.attention import Transformer2DModel, Transformer2DModelOutput
+from ...models.attention import DualTransformer2DModel, Transformer2DModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import is_accelerate_available, logging
@@ -94,12 +93,32 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
             if isinstance(module, Transformer2DModel):
                 parent_name, index = name.rsplit(".", 1)
                 index = int(index)
+
                 image_transformer = self.image_unet.get_submodule(parent_name)[index]
                 text_transformer = self.text_unet.get_submodule(parent_name)[index]
+                config = image_transformer.config
                 dual_transformer = DualTransformer2DModel(
-                    image_transformer, text_transformer, mix_ratio=mix_ratio, condition_types=condition_types
+                    num_attention_heads=config.num_attention_heads,
+                    attention_head_dim=config.attention_head_dim,
+                    in_channels=config.in_channels,
+                    num_layers=config.num_layers,
+                    dropout=config.dropout,
+                    norm_num_groups=config.norm_num_groups,
+                    cross_attention_dim=config.cross_attention_dim,
+                    attention_bias=config.attention_bias,
+                    sample_size=config.sample_size,
+                    num_vector_embeds=config.num_vector_embeds,
+                    activation_fn=config.activation_fn,
+                    num_embeds_ada_norm=config.num_embeds_ada_norm,
                 )
+                for i, type in enumerate(condition_types):
+                    if type == "image":
+                        dual_transformer.transformers[i] = image_transformer
+                    else:
+                        dual_transformer.transformers[i] = text_transformer
+
+                dual_transformer.mix_ratio = mix_ratio
                 self.image_unet.get_submodule(parent_name)[index] = dual_transformer
 
     def remove_dual_attention(self):
@@ -107,7 +126,7 @@
             if isinstance(module, DualTransformer2DModel):
                 parent_name, index = name.rsplit(".", 1)
                 index = int(index)
-                self.image_unet.get_submodule(parent_name)[index] = module.image_transformer
+                self.image_unet.get_submodule(parent_name)[index] = module.transformers[0]
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
     def enable_xformers_memory_efficient_attention(self):
@@ -412,6 +431,11 @@
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
+    def set_mix_ratio(self, mix_ratio):
+        for name, module in self.image_unet.named_modules():
+            if isinstance(module, DualTransformer2DModel):
+                module.mix_ratio = mix_ratio
+
     @torch.no_grad()
     def __call__(
         self,
@@ -539,6 +563,7 @@
         # 7. Combine the attention blocks of the image and text UNets
         self.convert_to_dual_attention(prompt_mix_ratio, prompt_types)
+        self.set_mix_ratio(prompt_mix_ratio)
 
         # 8. Denoising loop
         for i, t in enumerate(self.progress_bar(timesteps)):
@@ -575,36 +600,3 @@
             return (image,)
 
         return ImagePipelineOutput(images=image)
-
-
-class DualTransformer2DModel(nn.Module):
-    def __init__(self, image_transformer, text_transformer, mix_ratio=0.5, condition_types=("text", "image")):
-        super().__init__()
-        self.image_transformer = image_transformer
-        self.text_transformer = text_transformer
-        self.mix_ratio = mix_ratio
-        self.condition_types = condition_types
-
-    def forward(self, input_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
-        if self.condition_types[0] == "text":
-            condition_states = [encoder_hidden_states[:, :77], encoder_hidden_states[:, 77:]]
-        else:
-            condition_states = [encoder_hidden_states[:, :257], encoder_hidden_states[:, 257:]]
-
-        encoded_states = []
-        for i in range(2):
-            if self.condition_types[i] == "text":
-                text_output = self.text_transformer(input_states, condition_states[i], timestep, return_dict)
-                encoded_states.append(text_output[0])
-            else:
-                image_output = self.image_transformer(input_states, condition_states[i], timestep, return_dict)
-                encoded_states.append(image_output[0])
-            encoded_states[i] = encoded_states[i] - input_states
-
-        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
-        output_states = output_states + input_states
-
-        if not return_dict:
-            return (output_states,)
-
-        return Transformer2DModelOutput(sample=output_states)
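`convert_to_dual_attention` and `remove_dual_attention` rely on locating every `Transformer2DModel` by its dotted module name and swapping it through the parent container. The sketch below shows that swap pattern on a toy model; `TinyBlock` and `WrappedBlock` are made-up stand-ins, and unlike the pipeline code the names are collected before mutating, which keeps the toy example from re-visiting modules it has just wrapped. Only the `named_modules` / `get_submodule` / indexed-assignment mechanics mirror the pipeline.

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    """Stand-in for Transformer2DModel (illustrative only)."""

    def forward(self, x):
        return x


class WrappedBlock(nn.Module):
    """Stand-in for DualTransformer2DModel wrapping an original block."""

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        return self.block(x)


model = nn.Sequential(
    nn.Sequential(TinyBlock(), nn.ReLU()),
    nn.Sequential(TinyBlock()),
)

# 1. Collect the dotted names of the blocks to swap (here "0.0" and "1.0").
targets = [name for name, module in model.named_modules() if isinstance(module, TinyBlock)]

# 2. Swap each block in place through its parent container, like convert_to_dual_attention
#    does for every Transformer2DModel inside image_unet.
for name in targets:
    parent_name, index = name.rsplit(".", 1)
    block = model.get_submodule(name)
    model.get_submodule(parent_name)[int(index)] = WrappedBlock(block)

print(model)  # every TinyBlock is now wrapped in a WrappedBlock
```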
diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py
index 57f0b55446..568f674338 100644
--- a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py
+++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
+import tempfile
 import unittest
 
 import numpy as np
@@ -34,6 +36,47 @@ class VersatileDiffusionDualGuidedPipelineFastTests(PipelineTesterMixin, unittes
 @slow
 @require_torch_gpu
 class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_from_pretrained_save_pretrained(self):
+        pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = pipe(
+            first_prompt="first prompt",
+            second_prompt="second prompt",
+            prompt_mix_ratio=0.75,
+            generator=generator,
+            guidance_scale=7.5,
+            num_inference_steps=2,
+            output_type="numpy",
+        ).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe.save_pretrained(tmpdirname)
+            pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(tmpdirname)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = generator.manual_seed(0)
+        new_image = pipe(
+            first_prompt="first prompt",
+            second_prompt="second prompt",
+            prompt_mix_ratio=0.75,
+            generator=generator,
+            guidance_scale=7.5,
+            num_inference_steps=2,
+            output_type="numpy",
+        ).images
+
+        assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
+
     def test_inference_image_variations(self):
         pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
         pipe.to(torch_device)
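For context, the call pattern exercised by `test_from_pretrained_save_pretrained` translates directly into user code. The sketch below is assembled only from what the test shows (`first_prompt`, `second_prompt`, `prompt_mix_ratio`, the `diffusers/vd-official-test` checkpoint); the import path, prompts, step count, and the loop over mix ratios are illustrative assumptions, not part of this diff.

```python
import torch
from diffusers import VersatileDiffusionDualGuidedPipeline  # import path assumed from the tests

# Checkpoint id taken from the integration tests above; requires a CUDA device.
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)

# Sweep the mix ratio to see how strongly each prompt guides the result
# (this is what convert_to_dual_attention / set_mix_ratio control internally).
for mix_ratio in (0.25, 0.5, 0.75):
    image = pipe(
        first_prompt="first prompt",    # placeholder prompts, as in the test
        second_prompt="second prompt",
        prompt_mix_ratio=mix_ratio,
        generator=generator,
        guidance_scale=7.5,
        num_inference_steps=50,
        output_type="numpy",
    ).images[0]
```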