From f706729d3cefc9ab02f74c09dee6b655a24ec750 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 21 Nov 2022 23:56:35 +0100 Subject: [PATCH] text unet end to end --- ...onvert_versatile_diffusion_to_diffusers.py | 52 +- src/diffusers/__init__.py | 1 + src/diffusers/models/unet_2d.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 +- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/versatile_diffusion/__init__.py | 3 + .../versatile_diffusion/modeling_text_unet.py | 210 +++++--- .../pipeline_versatile_diffusion.py | 3 +- ...eline_versatile_diffusion_image_to_text.py | 461 ++++++++++++++++++ ...ine_versatile_diffusion_image_variation.py | 13 +- ...eline_versatile_diffusion_text_to_image.py | 9 +- .../dummy_torch_and_transformers_objects.py | 15 + .../test_versatile_diffusion_image_to_text.py | 56 +++ 13 files changed, 728 insertions(+), 100 deletions(-) create mode 100644 tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py diff --git a/scripts/convert_versatile_diffusion_to_diffusers.py b/scripts/convert_versatile_diffusion_to_diffusers.py index ca54f05f05..efa01a73b8 100644 --- a/scripts/convert_versatile_diffusion_to_diffusers.py +++ b/scripts/convert_versatile_diffusion_to_diffusers.py @@ -31,13 +31,14 @@ from diffusers import ( UNet2DConditionModel, VersatileDiffusionPipeline, ) +from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel +from diffusers.pipelines.versatile_diffusion.modeling_gpt2_optimus import GPT2OptimusForLatentConnector from transformers import ( CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) -from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetMultiDimConditionModel SCHEDULER_CONFIG = Namespace( @@ -241,7 +242,7 @@ def assign_to_checkpoint( # proj_attn.weight has to be converted from conv 1D to linear if "proj_attn.weight" in new_path: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: + elif path["old"] in old_checkpoint: checkpoint[new_path] = old_checkpoint[path["old"]] @@ -306,14 +307,14 @@ def create_text_unet_diffusers_config(unet_params): down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlockMultiDim" if unet_params.with_attn[i] else "DownBlockMultiDim" + block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlockMultiDim" if unet_params.with_attn[-i - 1] else "UpBlockMultiDim" + block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat" up_block_types.append(block_type) resolution //= 2 @@ -322,8 +323,8 @@ def create_text_unet_diffusers_config(unet_params): config = dict( sample_size=None, - in_channels=unet_params.input_channels, - out_channels=unet_params.output_channels, + in_channels=(unet_params.input_channels, 1, 1), + out_channels=(unet_params.output_channels, 1, 1), down_block_types=tuple(down_block_types), up_block_types=tuple(up_block_types), block_out_channels=tuple(block_out_channels), @@ -450,6 +451,17 @@ def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False): new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.bias" ) + elif f"input_blocks.{i}.0.weight" in unet_state_dict: + # text_unet uses linear layers in place of downsamplers + shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.bias" + ) paths = renew_resnet_paths(resnets) meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} @@ -512,10 +524,34 @@ def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False): new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.bias" ] - # Clear attentions as they have been attributed above. if len(attentions) == 2: attentions = [] + elif f"output_blocks.{i}.1.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.1.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.1.bias" + ) + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + elif f"output_blocks.{i}.2.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.2.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.2.bias" + ) if len(attentions): paths = renew_attention_paths(attentions) @@ -727,7 +763,7 @@ if __name__ == "__main__": converted_text_unet_checkpoint = convert_vd_unet_checkpoint( checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema ) - text_unet = UNetMultiDimConditionModel(**text_unet_config) + text_unet = UNetFlatConditionModel(**text_unet_config) text_unet.load_state_dict(converted_text_unet_checkpoint) # Convert the VAE model. diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b6bd0790e0..bedf36d516 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -76,6 +76,7 @@ if is_torch_available() and is_transformers_available(): VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, + VersatileDiffusionImageToTextPipeline, VQDiffusionPipeline, ) else: diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 0432405760..cd7767d10e 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -175,7 +175,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def forward( self, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index c3f2fb87b6..5a02a3ba1e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -201,7 +201,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): if slice_size is not None and self.config.attention_head_dim % slice_size != 0: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 60cde85f79..a87a94a9a5 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -28,6 +28,7 @@ if is_torch_available() and is_transformers_available(): VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, + VersatileDiffusionImageToTextPipeline, ) from .vq_diffusion import VQDiffusionPipeline diff --git a/src/diffusers/pipelines/versatile_diffusion/__init__.py b/src/diffusers/pipelines/versatile_diffusion/__init__.py index 60257f2728..58822a8f03 100644 --- a/src/diffusers/pipelines/versatile_diffusion/__init__.py +++ b/src/diffusers/pipelines/versatile_diffusion/__init__.py @@ -3,6 +3,9 @@ from ...utils import is_torch_available, is_transformers_available if is_transformers_available() and is_torch_available(): from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector + from .modeling_text_unet import UNetFlatConditionModel + from .pipeline_versatile_diffusion import VersatileDiffusionPipeline from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + from .pipeline_versatile_diffusion_image_to_text import VersatileDiffusionImageToTextPipeline diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 5c8aecfeb1..5fddb3dca9 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -6,8 +6,8 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...modeling_utils import ModelMixin -from ...models.embeddings import TimestepEmbedding, Timesteps from ...models.attention import Transformer2DModel +from ...models.embeddings import TimestepEmbedding, Timesteps from ...models.unet_2d_condition import UNet2DConditionOutput from ...utils import logging @@ -15,7 +15,7 @@ from ...utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def get_down_block_multi_dim( +def get_down_block( down_block_type, num_layers, in_channels, @@ -23,38 +23,45 @@ def get_down_block_multi_dim( temb_channels, add_downsample, resnet_eps, + resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, + downsample_padding=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownBlockMultiDim": - return DownBlockMultiDim( + if down_block_type == "DownBlockFlat": + return DownBlockFlat( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, + downsample_padding=downsample_padding, ) - elif down_block_type == "CrossAttnDownBlockMultiDim": + elif down_block_type == "CrossAttnDownBlockFlat": if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMultiDim") - return CrossAttnDownBlockMultiDim( + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat") + return CrossAttnDownBlockFlat( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, + downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, ) raise ValueError(f"{down_block_type} is not supported.") -def get_up_block_multi_dim( + +def get_up_block( up_block_type, num_layers, in_channels, @@ -63,13 +70,14 @@ def get_up_block_multi_dim( temb_channels, add_upsample, resnet_eps, + resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - if up_block_type == "UpBlockMultiDim": - return UpBlockMultiDim( + if up_block_type == "UpBlockFlat": + return UpBlockFlat( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -77,12 +85,13 @@ def get_up_block_multi_dim( temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, ) - elif up_block_type == "CrossAttnUpBlockMultiDim": + elif up_block_type == "CrossAttnUpBlockFlat": if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMultiDim") - return CrossAttnUpBlockMultiDim( + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat") + return CrossAttnUpBlockFlat( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -90,6 +99,7 @@ def get_up_block_multi_dim( temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, @@ -97,11 +107,11 @@ def get_up_block_multi_dim( raise ValueError(f"{up_block_type} is not supported.") -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete -class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): +# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat +class UNetFlatConditionModel(ModelMixin, ConfigMixin): r""" - UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. + UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a + timestep and returns sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library implements for all the models (such as downloading or saving, etc.) @@ -114,9 +124,9 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): The tuple of downsample blocks to use. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): The tuple of upsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. @@ -142,19 +152,18 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( - "CrossAttnDownBlockMultiDim", - "CrossAttnDownBlockMultiDim", - "CrossAttnDownBlockMultiDim", - "DownBlockMultiDim", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", ), up_block_types: Tuple[str] = ( - "UpBlockMultiDim", - "CrossAttnUpBlockMultiDim", - "CrossAttnUpBlockMultiDim", - "CrossAttnUpBlockMultiDim", + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", ), block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - block_second_dim: Tuple[int] = (4, 4, 4, 4), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, @@ -170,7 +179,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): time_embed_dim = block_out_channels[0] * 4 # input - self.conv_in = LinearMultiDim([in_channels, 1, 1], [block_out_channels[0], block_second_dim[0], 1]) + self.conv_in = LinearMultiDim(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) # time self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) @@ -187,25 +196,26 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] - second_dim = block_second_dim[i] is_final_block = i == len(block_out_channels) - 1 - down_block = get_down_block_multi_dim( + down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, - out_channels=[output_channel, second_dim, 1], + out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, + resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, ) self.down_blocks.append(down_block) # mid - self.mid_block = UNetMidBlockMultiDimCrossAttn( + self.mid_block = UNetMidBlockFlatCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, @@ -237,7 +247,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): else: add_upsample = False - up_block = get_up_block_multi_dim( + up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, in_channels=input_channel, @@ -246,6 +256,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): temb_channels=time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, + resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, @@ -256,7 +267,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() - self.conv_out = LinearMultiDim(block_out_channels[0], [out_channels, 1, 1]) + self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): if slice_size is not None and self.config.attention_head_dim % slice_size != 0: @@ -292,9 +303,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) def _set_gradient_checkpointing(self, module, value=False): - if isinstance( - module, (CrossAttnDownBlockMultiDim, DownBlockMultiDim, CrossAttnUpBlockMultiDim, UpBlockMultiDim) - ): + if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): module.gradient_checkpointing = value def forward( @@ -308,7 +317,8 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + encoder_hidden_states (`torch.FloatTensor`): + (batch_size, sequence_length, hidden_size) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -410,23 +420,25 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): class LinearMultiDim(nn.Linear): - def __init__(self, in_features, out_features, second_dim=4, *args, **kwargs): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) + if out_features is None: + out_features = in_features out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) self.in_features_multidim = in_features self.out_features_multidim = out_features super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) - def forward(self, x): - shape = x.shape - n = len(self.in_features_multidim) - x = x.view(*shape[0:-n], self.in_features) - y = super().forward(x) - y = y.view(*shape[0:-n], *self.out_features_multidim) - return y + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim) + return output_tensor -class ResnetBlockMultiDim(nn.Module): +class ResnetBlockFlat(nn.Module): def __init__( self, *, @@ -440,29 +452,31 @@ class ResnetBlockMultiDim(nn.Module): eps=1e-6, time_embedding_norm="default", use_in_shortcut=None, + second_dim=4, + **kwargs, ): super().__init__() self.pre_norm = pre_norm self.pre_norm = True - in_channels = [in_channels] if isinstance(in_channels, int) else list(in_channels) - in_channels_prod = np.array(in_channels).prod() + in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels) + self.in_channels_prod = np.array(in_channels).prod() self.channels_multidim = in_channels if out_channels is not None: - out_channels = [out_channels] if isinstance(out_channels, int) else list(out_channels) + out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels) out_channels_prod = np.array(out_channels).prod() self.out_channels_multidim = out_channels else: - out_channels_prod = in_channels_prod + out_channels_prod = self.in_channels_prod self.out_channels_multidim = self.channels_multidim self.time_embedding_norm = time_embedding_norm if groups_out is None: groups_out = groups - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels_prod, eps=eps, affine=True) - self.conv1 = torch.nn.Conv2d(in_channels_prod, out_channels_prod, kernel_size=1, padding=0) + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0) if temb_channels is not None: self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) @@ -475,15 +489,20 @@ class ResnetBlockMultiDim(nn.Module): self.nonlinearity = nn.SiLU() - self.use_in_shortcut = in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut + self.use_in_shortcut = self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = torch.nn.Conv2d( - in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 + self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 ) def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + hidden_states = input_tensor hidden_states = self.norm1(hidden_states) @@ -505,10 +524,15 @@ class ResnetBlockMultiDim(nn.Module): output_tensor = input_tensor + hidden_states + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim) + + print("resblock.output_tensor", output_tensor.abs().sum()) return output_tensor -class DownBlockMultiDim(nn.Module): +# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class DownBlockFlat(nn.Module): def __init__( self, in_channels: int, @@ -518,9 +542,12 @@ class DownBlockMultiDim(nn.Module): num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, + output_scale_factor=1.0, add_downsample=True, + downsample_padding=1, ): super().__init__() resnets = [] @@ -528,7 +555,7 @@ class DownBlockMultiDim(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlockMultiDim( + ResnetBlockFlat( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -536,6 +563,8 @@ class DownBlockMultiDim(nn.Module): groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) @@ -543,7 +572,13 @@ class DownBlockMultiDim(nn.Module): self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)]) + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) else: self.downsamplers = None @@ -576,7 +611,8 @@ class DownBlockMultiDim(nn.Module): return hidden_states, output_states -class CrossAttnDownBlockMultiDim(nn.Module): +# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class CrossAttnDownBlockFlat(nn.Module): def __init__( self, in_channels: int, @@ -586,11 +622,14 @@ class CrossAttnDownBlockMultiDim(nn.Module): num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, add_downsample=True, ): super().__init__() @@ -603,7 +642,7 @@ class CrossAttnDownBlockMultiDim(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlockMultiDim( + ResnetBlockFlat( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -611,14 +650,16 @@ class CrossAttnDownBlockMultiDim(nn.Module): groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( Transformer2DModel( attn_num_head_channels, - out_channels[0] // attn_num_head_channels, - in_channels=out_channels[0], + out_channels // attn_num_head_channels, + in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, @@ -628,7 +669,13 @@ class CrossAttnDownBlockMultiDim(nn.Module): self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)]) + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) else: self.downsamplers = None @@ -687,7 +734,8 @@ class CrossAttnDownBlockMultiDim(nn.Module): return hidden_states, output_states -class UpBlockMultiDim(nn.Module): +# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class UpBlockFlat(nn.Module): def __init__( self, in_channels: int, @@ -698,8 +746,10 @@ class UpBlockMultiDim(nn.Module): num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, + output_scale_factor=1.0, add_upsample=True, ): super().__init__() @@ -710,7 +760,7 @@ class UpBlockMultiDim(nn.Module): resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlockMultiDim( + ResnetBlockFlat( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -718,6 +768,8 @@ class UpBlockMultiDim(nn.Module): groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) @@ -725,7 +777,7 @@ class UpBlockMultiDim(nn.Module): self.resnets = nn.ModuleList(resnets) if add_upsample: - self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)]) + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None @@ -757,7 +809,8 @@ class UpBlockMultiDim(nn.Module): return hidden_states -class CrossAttnUpBlockMultiDim(nn.Module): +# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class CrossAttnUpBlockFlat(nn.Module): def __init__( self, in_channels: int, @@ -768,11 +821,13 @@ class CrossAttnUpBlockMultiDim(nn.Module): num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, attention_type="default", + output_scale_factor=1.0, add_upsample=True, ): super().__init__() @@ -787,7 +842,7 @@ class CrossAttnUpBlockMultiDim(nn.Module): resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlockMultiDim( + ResnetBlockFlat( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -795,6 +850,8 @@ class CrossAttnUpBlockMultiDim(nn.Module): groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) @@ -812,7 +869,7 @@ class CrossAttnUpBlockMultiDim(nn.Module): self.resnets = nn.ModuleList(resnets) if add_upsample: - self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)]) + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None @@ -879,7 +936,8 @@ class CrossAttnUpBlockMultiDim(nn.Module): return hidden_states -class UNetMidBlockMultiDimCrossAttn(nn.Module): +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatCrossAttn(nn.Module): def __init__( self, in_channels: int, @@ -888,10 +946,12 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module): num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, attention_type="default", + output_scale_factor=1.0, cross_attention_dim=1280, **kwargs, ): @@ -903,7 +963,7 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module): # there is always at least one resnet resnets = [ - ResnetBlockMultiDim( + ResnetBlockFlat( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -911,6 +971,8 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module): groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] @@ -928,7 +990,7 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module): ) ) resnets.append( - ResnetBlockMultiDim( + ResnetBlockFlat( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -936,6 +998,8 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module): groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index 8b8b59bc26..89453edcb1 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -9,7 +9,8 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import logging -from . import VersatileDiffusionImageVariationPipeline, VersatileDiffusionTextToImagePipeline +from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline +from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py index e69de29bb2..5e49cab220 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py @@ -0,0 +1,461 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import torch.utils.checkpoint + +import PIL +from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection, GPT2Tokenizer + +from .modeling_text_unet import UNetFlatConditionModel +from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import Transformer2DModel +from ...pipeline_utils import DiffusionPipeline, BaseOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class TextPipelineOutput(BaseOutput): + """ + Output class for text generation pipelines. + + Args: + text (`List[str]` or `np.ndarray`) + List of generated text of length `batch_size` or a numpy array of tokens of shape `(batch_size, num_tokens)`. + """ + + text: Union[List[str], np.ndarray] + + +class VersatileDiffusionImageToTextPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + image_feature_extractor: CLIPFeatureExtractor + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + def __init__( + self, + image_feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + image_feature_extractor=image_feature_extractor, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + + self.text_vae_decoder = GPT2OptimusForLatentConnector.from_pretrained("fusing/gpt2_optimus") + self.text_vae_tokenizer = GPT2Tokenizer.from_pretrained("fusing/gpt2_optimus") + + def swap_unet_attention_blocks(self): + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( + self.text_unet.get_submodule(parent_name)[index], + self.image_unet.get_submodule(parent_name)[index], + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + image_embeddings = self.image_encoder(image_input.pixel_values.to(self.device)) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images: List[str] + if negative_prompt is None: + uncond_images = [np.zeros((512, 512, 3))] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, PIL.Image.Image): + uncond_images = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_images = negative_prompt + + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + uncond_embeddings = self.image_encoder(uncond_images.pixel_values.to(self.device)) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = latents.reshape(latents.shape[:-2]) + self.text_vae_decoder = self.text_vae_decoder.to(self._execution_device) + bos_token = self.text_vae_tokenizer.bos_token_id + output = self.text_vae_decoder.generate(bos_token_id=bos_token, past=latents) + return output + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, callback_steps): + if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor): + raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, 1, 1) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "str", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + image_embeddings = self._encode_prompt( + image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.text_unet.in_channels[0] + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Swap the attention blocks between the image and text UNets + self.swap_unet_attention_blocks() + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.text_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Swap the attention blocks backs in case the UNets are reused in another pipeline + self.swap_unet_attention_blocks() + + # 10. Post-processing + text = self.decode_latents(latents) + + # 11. Convert to strings + if output_type == "str": + text = self.text_vae_tokenizer.decode(text) + + if not return_dict: + return (text,) + + return TextPipelineOutput(text=text) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index e37010d1f8..bf764f47ae 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -22,8 +22,7 @@ import torch.utils.checkpoint import PIL from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection -from ...models import AutoencoderKL, UNet2DConditionModel, VQModel -from ...models.attention import Transformer2DModel +from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import is_accelerate_available, logging @@ -73,16 +72,6 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): scheduler=scheduler, ) - def swap_unet_attention_blocks(self): - for name, module in self.image_unet.named_modules(): - if isinstance(module, Transformer2DModel): - parent_name, index = name.rsplit(".", 1) - index = int(index) - self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( - self.text_unet.get_submodule(parent_name)[index], - self.image_unet.get_submodule(parent_name)[index], - ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet def enable_xformers_memory_efficient_attention(self): r""" diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index 991b58c357..d28c88cb29 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -20,7 +20,8 @@ import torch.utils.checkpoint from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer -from ...models import AutoencoderKL, UNet2DConditionModel, VQModel +from .modeling_text_unet import UNetFlatConditionModel +from ...models import UNet2DConditionModel, AutoencoderKL from ...models.attention import Transformer2DModel from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -52,7 +53,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): image_feature_extractor: CLIPFeatureExtractor text_encoder: CLIPTextModelWithProjection image_unet: UNet2DConditionModel - text_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel vae: AutoencoderKL scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] @@ -61,8 +62,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): tokenizer: CLIPTokenizer, text_encoder: CLIPTextModelWithProjection, image_unet: UNet2DConditionModel, - text_unet: UNet2DConditionModel, - vae: Union[VQModel, AutoencoderKL], + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], ): super().__init__() diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2ad0ead440..7aa12e46b4 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -139,6 +139,21 @@ class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class VersatileDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py new file mode 100644 index 0000000000..f8ec184c77 --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionImageToTextPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase): + def test_inference_image_to_text(self): + pipe = VersatileDiffusionImageToTextPipeline.from_pretrained("scripts/vd_official") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + image_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + tokens = pipe( + image=image_prompt, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).text + + assert tokens.shape == (1, 30) + expected_tokens = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + assert self.assertItemsEqual(tokens[0] , expected_tokens)