1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add transformers depth

This commit is contained in:
Patrick von Platen
2023-06-23 12:16:51 +02:00
parent 57b8406ef0
commit 39b0b97aac
3 changed files with 36 additions and 9 deletions

View File

@@ -38,6 +38,7 @@ def get_down_block(
add_downsample,
resnet_eps,
resnet_act_fn,
num_transformer_blocks=1,
num_attention_heads=None,
resnet_groups=None,
cross_attention_dim=None,
@@ -106,6 +107,7 @@ def get_down_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
return CrossAttnDownBlock2D(
num_layers=num_layers,
num_transformer_blocks=num_transformer_blocks,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -227,6 +229,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
num_transformer_blocks=1,
num_attention_heads=None,
resnet_groups=None,
cross_attention_dim=None,
@@ -281,6 +284,7 @@ def get_up_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
return CrossAttnUpBlock2D(
num_layers=num_layers,
num_transformer_blocks=num_transformer_blocks,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
@@ -506,6 +510,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
num_transformer_blocks: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -548,7 +553,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
num_layers=num_transformer_blocks,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -829,6 +834,7 @@ class CrossAttnDownBlock2D(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
num_transformer_blocks: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -873,7 +879,7 @@ class CrossAttnDownBlock2D(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
num_layers=num_transformer_blocks,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1939,6 +1945,7 @@ class CrossAttnUpBlock2D(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
num_transformer_blocks: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -1984,7 +1991,7 @@ class CrossAttnUpBlock2D(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
num_layers=num_transformer_blocks,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,

View File

@@ -96,6 +96,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
num_transformer_blocks (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
@@ -168,6 +170,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
num_transformer_blocks: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -381,6 +384,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if isinstance(num_transformer_blocks, int):
num_transformer_blocks = [num_transformer_blocks] * len(down_block_types)
if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
@@ -399,6 +405,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block[i],
num_transformer_blocks=num_transformer_blocks[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
@@ -424,6 +431,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# mid
if mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = UNetMidBlock2DCrossAttn(
num_transformer_blocks=num_transformer_blocks[-1],
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
@@ -465,6 +473,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_num_transformer_blocks = list(reversed(num_transformer_blocks))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
@@ -485,6 +494,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
up_block = get_up_block(
up_block_type,
num_layers=reversed_layers_per_block[i] + 1,
num_transformer_blocks=reversed_num_transformer_blocks[i],
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,

View File

@@ -233,7 +233,10 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
else:
unet_params = original_config.model.params.unet_config.params
if original_config.model.params.unet_config is not None:
unet_params = original_config.model.params.unet_config.params
else:
unet_params = original_config.model.params.network_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
@@ -253,6 +256,11 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
up_block_types.append(block_type)
resolution //= 2
if unet_params.transformer_depth is not None:
num_transformer_blocks = unet_params.transformer_depth if isinstance(unet_params.transformer_depth, int) else list(unet_params.transformer_depth)
else:
num_transformer_blocks = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
@@ -262,7 +270,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim = [5, 10, 20, 20]
head_dim = [5 * c for c in list(unet_params.channel_mult)]
class_embed_type = None
projection_class_embeddings_input_dim = None
@@ -286,6 +294,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
"use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
"num_transformer_blocks": num_transformer_blocks,
}
if controlnet:
@@ -1172,9 +1181,9 @@ def download_from_original_stable_diffusion_ckpt(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
)
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
num_train_timesteps = original_config.model.params.timesteps or 1000
beta_start = original_config.model.params.linear_start or 0.02
beta_end = original_config.model.params.linear_end or 0.085
scheduler = DDIMScheduler(
beta_end=beta_end,
@@ -1216,8 +1225,9 @@ def download_from_original_stable_diffusion_ckpt(
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
)
unet.load_state_dict(converted_unet_checkpoint)
# Works!
import ipdb; ipdb.set_trace()
# Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)