diff --git a/scripts/convert_versatile_diffusion_to_diffusers.py b/scripts/convert_versatile_diffusion_to_diffusers.py index abdcc495ce..ca54f05f05 100644 --- a/scripts/convert_versatile_diffusion_to_diffusers.py +++ b/scripts/convert_versatile_diffusion_to_diffusers.py @@ -37,6 +37,7 @@ from transformers import ( CLIPTokenizer, CLIPVisionModelWithProjection, ) +from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetMultiDimConditionModel SCHEDULER_CONFIG = Namespace( @@ -256,9 +257,9 @@ def conv_attn_to_linear(checkpoint): checkpoint[key] = checkpoint[key][:, :, 0] -def create_unet_diffusers_config(unet_params): +def create_image_unet_diffusers_config(unet_params): """ - Creates a config for the diffusers based on the config of the LDM model. + Creates a config for the diffusers based on the config of the VD model. """ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] @@ -295,6 +296,45 @@ def create_unet_diffusers_config(unet_params): return config +def create_text_unet_diffusers_config(unet_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlockMultiDim" if unet_params.with_attn[i] else "DownBlockMultiDim" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlockMultiDim" if unet_params.with_attn[-i - 1] else "UpBlockMultiDim" + up_block_types.append(block_type) + resolution //= 2 + + if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): + raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") + + config = dict( + sample_size=None, + in_channels=unet_params.input_channels, + out_channels=unet_params.output_channels, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_noattn_blocks[0], + cross_attention_dim=unet_params.context_dim, + attention_head_dim=unet_params.num_heads, + ) + + return config + + def create_vae_diffusers_config(vae_params): """ Creates a config for the diffusers based on the config of the VD model. @@ -674,7 +714,7 @@ if __name__ == "__main__": # Convert the UNet2DConditionModel models. 
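# --- Illustrative sketch (not part of the patch) --------------------------------------
# What `create_text_unet_diffusers_config` is expected to produce for a VD-style text
# UNet config. The Namespace values below are hypothetical stand-ins; the real
# TEXT_UNET_CONFIG is defined elsewhere in this script and is not shown in this diff.
from argparse import Namespace

_example_text_unet_params = Namespace(
    input_channels=768,
    output_channels=768,
    model_channels=320,
    channel_mult=[1, 2, 4, 4],
    num_noattn_blocks=[2, 2, 2, 2],
    with_attn=[True, True, True, False],
    context_dim=768,
    num_heads=8,
)
# create_text_unet_diffusers_config(_example_text_unet_params) would then return a dict
# with, among others:
#   down_block_types=("CrossAttnDownBlockMultiDim",) * 3 + ("DownBlockMultiDim",),
#   up_block_types=("UpBlockMultiDim",) + ("CrossAttnUpBlockMultiDim",) * 3,
#   block_out_channels=(320, 640, 1280, 1280), layers_per_block=2,
#   cross_attention_dim=768, attention_head_dim=8, in_channels=768, out_channels=768,
#   sample_size=None
# ---------------------------------------------------------------------------------------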
if args.unet_checkpoint_path is not None: # image UNet - image_unet_config = create_unet_diffusers_config(IMAGE_UNET_CONFIG) + image_unet_config = create_image_unet_diffusers_config(IMAGE_UNET_CONFIG) checkpoint = torch.load(args.unet_checkpoint_path) converted_image_unet_checkpoint = convert_vd_unet_checkpoint( checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema @@ -683,28 +723,12 @@ if __name__ == "__main__": image_unet.load_state_dict(converted_image_unet_checkpoint) # text UNet - text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG) + text_unet_config = create_text_unet_diffusers_config(TEXT_UNET_CONFIG) converted_text_unet_checkpoint = convert_vd_unet_checkpoint( checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema ) - text_unet = UNet2DConditionModel(**text_unet_config) - # TEMP hack to skip converting the 1x1 blocks for the text unet - del converted_text_unet_checkpoint["conv_in.weight"] - del converted_text_unet_checkpoint["conv_in.bias"] - del converted_text_unet_checkpoint["conv_out.weight"] - for block in ["down_blocks", "mid_block", "up_blocks"]: - for i in range(4): - for j in range(3): - for module in ["time_emb_proj", "conv1", "norm1", "conv2", "norm2", "conv_shortcut"]: - for type in ["weight", "bias"]: - if block == "mid_block": - key = f"{block}.resnets.{j}.{module}.{type}" - else: - key = f"{block}.{i}.resnets.{j}.{module}.{type}" - if key in converted_text_unet_checkpoint: - del converted_text_unet_checkpoint[key] - # END TEMP hack - text_unet.load_state_dict(converted_text_unet_checkpoint, strict=False) + text_unet = UNetMultiDimConditionModel(**text_unet_config) + text_unet.load_state_dict(converted_text_unet_checkpoint) # Convert the VAE model. 
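# Illustrative sanity check (not part of the patch): before moving on to the VAE
# conversion below, one could verify that the converted state dict now lines up with the
# dedicated UNetMultiDimConditionModel exactly -- the strict `load_state_dict` call above
# only succeeds when there are no missing or unexpected keys.
_model_keys = set(text_unet.state_dict().keys())
_ckpt_keys = set(converted_text_unet_checkpoint.keys())
assert _model_keys == _ckpt_keys, (_model_keys - _ckpt_keys, _ckpt_keys - _model_keys)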
    if args.vae_checkpoint_path is not None:
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
new file mode 100644
index 0000000000..5c8aecfeb1
--- /dev/null
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -0,0 +1,971 @@
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...modeling_utils import ModelMixin
+from ...models.embeddings import TimestepEmbedding, Timesteps
+from ...models.attention import Transformer2DModel
+from ...models.unet_2d_condition import UNet2DConditionOutput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def get_down_block_multi_dim(
+    down_block_type,
+    num_layers,
+    in_channels,
+    out_channels,
+    temb_channels,
+    add_downsample,
+    resnet_eps,
+    attn_num_head_channels,
+    resnet_groups=None,
+    cross_attention_dim=None,
+):
+    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+    if down_block_type == "DownBlockMultiDim":
+        return DownBlockMultiDim(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_groups=resnet_groups,
+        )
+    elif down_block_type == "CrossAttnDownBlockMultiDim":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMultiDim")
+        return CrossAttnDownBlockMultiDim(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_groups=resnet_groups,
+            cross_attention_dim=cross_attention_dim,
+            attn_num_head_channels=attn_num_head_channels,
+        )
+    raise ValueError(f"{down_block_type} is not supported.")
+
+def get_up_block_multi_dim(
+    up_block_type,
+    num_layers,
+    in_channels,
+    out_channels,
+    prev_output_channel,
+    temb_channels,
+    add_upsample,
+    resnet_eps,
+    attn_num_head_channels,
+    resnet_groups=None,
+    cross_attention_dim=None,
+):
+    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+    if up_block_type == "UpBlockMultiDim":
+        return UpBlockMultiDim(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_groups=resnet_groups,
+        )
+    elif up_block_type == "CrossAttnUpBlockMultiDim":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMultiDim")
+        return CrossAttnUpBlockMultiDim(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_groups=resnet_groups,
+            cross_attention_dim=cross_attention_dim,
+            attn_num_head_channels=attn_num_head_channels,
+        )
+    raise ValueError(f"{up_block_type} is not supported.")
+
+
+# Adapted from diffusers.models.unet_2d_condition.UNet2DConditionModel
+class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
+    r"""
+    UNetMultiDimConditionModel is a conditional UNet model that takes in a noisy sample, conditional state, and a timestep
+    and returns a sample shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the models (such as downloading or saving, etc.)
+
+    Parameters:
+        sample_size (`int`, *optional*): The size of the input sample.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockMultiDim", "CrossAttnDownBlockMultiDim", "CrossAttnDownBlockMultiDim", "DownBlockMultiDim")`):
+            The tuple of downsample blocks to use.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockMultiDim", "CrossAttnUpBlockMultiDim", "CrossAttnUpBlockMultiDim", "CrossAttnUpBlockMultiDim",)`):
+            The tuple of upsample blocks to use.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockMultiDim", + "CrossAttnDownBlockMultiDim", + "CrossAttnDownBlockMultiDim", + "DownBlockMultiDim", + ), + up_block_types: Tuple[str] = ( + "UpBlockMultiDim", + "CrossAttnUpBlockMultiDim", + "CrossAttnUpBlockMultiDim", + "CrossAttnUpBlockMultiDim", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + block_second_dim: Tuple[int] = (4, 4, 4, 4), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: int = 8, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = LinearMultiDim([in_channels, 1, 1], [block_out_channels[0], block_second_dim[0], 1]) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + second_dim = block_second_dim[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block_multi_dim( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=[output_channel, second_dim, 1], + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockMultiDimCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block_multi_dim( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + 
attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = LinearMultiDim(block_out_channels[0], [out_channels, 1, 1]) + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.config.attention_head_dim % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.config.attention_head_dim}" + ) + if slice_size is not None and slice_size > self.config.attention_head_dim: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.config.attention_head_dim}" + ) + + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + self.mid_block.set_attention_slice(slice_size) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, (CrossAttnDownBlockMultiDim, DownBlockMultiDim, CrossAttnUpBlockMultiDim, UpBlockMultiDim) + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
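# (Illustrative note, not part of the patch) For example, with the default four blocks
# there are three upsampling layers, so `default_overall_up_factor` below is 2**3 = 8:
# a 64x64 sample divides evenly and nothing special happens, while a 9x9 sample would
# trigger `forward_upsample_size` and the explicit `upsample_size` hand-off further down.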
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. 
post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return UNet2DConditionOutput(sample=sample)
+
+
+class LinearMultiDim(nn.Linear):
+    def __init__(self, in_features, out_features, second_dim=4, *args, **kwargs):
+        in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features)
+        out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features)
+        self.in_features_multidim = in_features
+        self.out_features_multidim = out_features
+        super().__init__(np.array(in_features).prod(), np.array(out_features).prod())
+
+    def forward(self, x, *args, **kwargs):  # extra args (e.g. `upsample_size` from the up blocks) are ignored
+        shape = x.shape
+        n = len(self.in_features_multidim)
+        x = x.view(*shape[0:-n], self.in_features)
+        y = super().forward(x)
+        y = y.view(*shape[0:-n], *self.out_features_multidim)
+        return y
+
+
+class ResnetBlockMultiDim(nn.Module):
+    def __init__(
+        self,
+        *,
+        in_channels,
+        out_channels=None,
+        dropout=0.0,
+        temb_channels=512,
+        groups=32,
+        groups_out=None,
+        pre_norm=True,
+        eps=1e-6,
+        time_embedding_norm="default",
+        use_in_shortcut=None,
+    ):
+        super().__init__()
+        self.pre_norm = pre_norm
+        self.pre_norm = True
+
+        in_channels = [in_channels] if isinstance(in_channels, int) else list(in_channels)
+        in_channels_prod = np.array(in_channels).prod()
+        self.channels_multidim = in_channels
+
+        if out_channels is not None:
+            out_channels = [out_channels] if isinstance(out_channels, int) else list(out_channels)
+            out_channels_prod = np.array(out_channels).prod()
+            self.out_channels_multidim = out_channels
+        else:
+            out_channels_prod = in_channels_prod
+            self.out_channels_multidim = self.channels_multidim
+        self.time_embedding_norm = time_embedding_norm
+
+        if groups_out is None:
+            groups_out = groups
+
+        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels_prod, eps=eps, affine=True)
+        self.conv1 = torch.nn.Conv2d(in_channels_prod, out_channels_prod, kernel_size=1, padding=0)
+
+        if temb_channels is not None:
+            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)
+        else:
+            self.time_emb_proj = None
+
+        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0)
+
+        self.nonlinearity = nn.SiLU()
+
+        self.use_in_shortcut = in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
+
+        self.conv_shortcut = None
+        if self.use_in_shortcut:
+            self.conv_shortcut = torch.nn.Conv2d(
+                in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
+            )
+
+    def forward(self, input_tensor, temb):
+        hidden_states = input_tensor
+
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        if temb is not None:
+            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+            hidden_states = hidden_states + temb
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        if self.conv_shortcut is not None:
+            input_tensor = self.conv_shortcut(input_tensor)
+
+        output_tensor = input_tensor + hidden_states
+
+        return output_tensor
+
+
+class DownBlockMultiDim(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_downsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockMultiDim( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMultiDim(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockMultiDim( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels[0] // attn_num_head_channels, + in_channels=out_channels[0], + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + 
attn._set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class UpBlockMultiDim(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockMultiDim( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnUpBlockMultiDim(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_groups: int = 
32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockMultiDim( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UNetMidBlockMultiDimCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: 
float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + cross_attention_dim=1280, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlockMultiDim( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockMultiDim( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py index 4d2441103b..e69de29bb2 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py @@ -1,408 +0,0 @@ -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...modeling_utils import ModelMixin -from ...models.embeddings import TimestepEmbedding, Timesteps -from ...models.unet_2d_condition import UNet2DConditionOutput -from ...utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class UNetMultiDimConditionModel(ModelMixin, ConfigMixin): - r""" - UNet2DConditionModel is a conditional 2D UNet model that 
takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - Parameters: - sample_size (`int`, *optional*): The size of the input sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): - The tuple of upsample blocks to use. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlockMultiDim", - "CrossAttnDownBlockMultiDim", - "CrossAttnDownBlockMultiDim", - "DownBlockMultiDim", - ), - up_block_types: Tuple[str] = ( - "UpBlockMultiDim", - "CrossAttnUpBlockMultiDim", - "CrossAttnUpBlockMultiDim", - "CrossAttnUpBlockMultiDim", - ), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - block_second_dim: Tuple[int] = (4, 4, 4, 4), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: int = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: int = 8, - ): - super().__init__() - - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - - # input - self.conv_in = LinearMultiDim([in_channels, 1, 1], block_out_channels[0], kernel_size=3, padding=(1, 1)) - - # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = self.get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, - downsample_padding=downsample_padding, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlockMultiDimCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, - resnet_groups=norm_num_groups, - ) - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = self.get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - 
cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - ): - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownBlockMultiDim": - return DownBlockMultiDim( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - ) - elif down_block_type == "CrossAttnDownBlockMultiDim": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") - return CrossAttnDownBlockMultiDim( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - ) - - def set_attention_slice(self, slice_size): - if slice_size is not None and self.config.attention_head_dim % slice_size != 0: - raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.config.attention_head_dim}" - ) - if slice_size is not None and slice_size > self.config.attention_head_dim: - raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.config.attention_head_dim}" - ) - - for block in self.down_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) - - self.mid_block.set_attention_slice(slice_size) - - for block in self.up_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) - - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - for block in self.down_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - - self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - - for block in self.up_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance( - module, (CrossAttnDownBlockMultiDim, DownBlockMultiDim, CrossAttnUpBlockMultiDim, UpBlockMultiDim) - ): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - return_dict: bool = True, - ) 
-> Union[UNet2DConditionOutput, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) - - # 5. 
up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - upsample_size=upsample_size, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) - - -class LinearMultiDim(nn.Linear): - def __init__(self, in_features, out_features, *args, **kwargs): - in_features = [in_features] if isinstance(in_features, int) else list(in_features) - out_features = [out_features] if isinstance(out_features, int) else list(out_features) - self.in_features_multidim = in_features - self.out_features_multidim = out_features - super().__init__(np.array(in_features).prod(), np.array(out_features).prod(), *args, **kwargs) - - def forward(self, x): - shape = x.shape - n = len(self.in_features_multidim) - x = x.view(*shape[0:-n], self.in_features) - y = super().forward(x) - y = y.view(*shape[0:-n], *self.out_features_multidim) - return y - - -class UNetMidBlockMultiDimCrossAttn(nn.Module): - pass - - -class DownBlockMultiDim(nn.Module): - pass - - -class CrossAttnDownBlockMultiDim(nn.Module): - pass - - -class UpBlockMultiDim(nn.Module): - pass - - -class CrossAttnUpBlockMultiDim(nn.Module): - pass
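# ---------------------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the new LinearMultiDim from
# modeling_text_unet.py reshapes its input. It flattens the trailing "multi-dim" feature
# axes, applies an ordinary nn.Linear over the flattened features, and restores the
# target multi-dim shape -- which is how the text UNet can project flat feature vectors
# (e.g. conv_in: [in_channels, 1, 1] -> [block_out_channels[0], block_second_dim[0], 1])
# without any spatial convolutions. The tensor sizes below are arbitrary examples.
import torch

from diffusers.pipelines.versatile_diffusion.modeling_text_unet import LinearMultiDim

layer = LinearMultiDim([320, 4, 1], [640, 4, 1])  # internally an nn.Linear(1280, 2560)
x = torch.randn(2, 320, 4, 1)                     # (batch, *in_features_multidim)
y = layer(x)
print(y.shape)  # torch.Size([2, 640, 4, 1])

# With int arguments the trailing dims default to (second_dim, 1):
proj = LinearMultiDim(768, 320)                   # [768, 4, 1] -> [320, 4, 1]
print(proj(torch.randn(2, 768, 4, 1)).shape)      # torch.Size([2, 320, 4, 1])
# ---------------------------------------------------------------------------------------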