mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add transformers depth
This commit is contained in:
@@ -38,6 +38,7 @@ def get_down_block(
|
||||
add_downsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
num_transformer_blocks=1,
|
||||
num_attention_heads=None,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
@@ -106,6 +107,7 @@ def get_down_block(
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
|
||||
return CrossAttnDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
num_transformer_blocks=num_transformer_blocks,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -227,6 +229,7 @@ def get_up_block(
|
||||
add_upsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
num_transformer_blocks=1,
|
||||
num_attention_heads=None,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
@@ -281,6 +284,7 @@ def get_up_block(
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
|
||||
return CrossAttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
num_transformer_blocks=num_transformer_blocks,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
@@ -506,6 +510,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
num_transformer_blocks: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -548,7 +553,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
num_layers=num_transformer_blocks,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -829,6 +834,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
num_transformer_blocks: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -873,7 +879,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
num_layers=num_transformer_blocks,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -1939,6 +1945,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
num_transformer_blocks: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -1984,7 +1991,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
num_layers=num_transformer_blocks,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
|
||||
@@ -96,6 +96,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
num_transformer_blocks (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
@@ -168,6 +170,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
||||
num_transformer_blocks: Union[int, Tuple[int]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
@@ -381,6 +384,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if isinstance(layers_per_block, int):
|
||||
layers_per_block = [layers_per_block] * len(down_block_types)
|
||||
|
||||
if isinstance(num_transformer_blocks, int):
|
||||
num_transformer_blocks = [num_transformer_blocks] * len(down_block_types)
|
||||
|
||||
if class_embeddings_concat:
|
||||
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
||||
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
||||
@@ -399,6 +405,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block[i],
|
||||
num_transformer_blocks=num_transformer_blocks[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
@@ -424,6 +431,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
# mid
|
||||
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
num_transformer_blocks=num_transformer_blocks[-1],
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
@@ -465,6 +473,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_layers_per_block = list(reversed(layers_per_block))
|
||||
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
||||
reversed_num_transformer_blocks = list(reversed(num_transformer_blocks))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
@@ -485,6 +494,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=reversed_layers_per_block[i] + 1,
|
||||
num_transformer_blocks=reversed_num_transformer_blocks[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
|
||||
@@ -233,7 +233,10 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
if controlnet:
|
||||
unet_params = original_config.model.params.control_stage_config.params
|
||||
else:
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
if original_config.model.params.unet_config is not None:
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
else:
|
||||
unet_params = original_config.model.params.network_config.params
|
||||
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
|
||||
@@ -253,6 +256,11 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
if unet_params.transformer_depth is not None:
|
||||
num_transformer_blocks = unet_params.transformer_depth if isinstance(unet_params.transformer_depth, int) else list(unet_params.transformer_depth)
|
||||
else:
|
||||
num_transformer_blocks = 1
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||
|
||||
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
||||
@@ -262,7 +270,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
if use_linear_projection:
|
||||
# stable diffusion 2-base-512 and 2-768
|
||||
if head_dim is None:
|
||||
head_dim = [5, 10, 20, 20]
|
||||
head_dim = [5 * c for c in list(unet_params.channel_mult)]
|
||||
|
||||
class_embed_type = None
|
||||
projection_class_embeddings_input_dim = None
|
||||
@@ -286,6 +294,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
"use_linear_projection": use_linear_projection,
|
||||
"class_embed_type": class_embed_type,
|
||||
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
||||
"num_transformer_blocks": num_transformer_blocks,
|
||||
}
|
||||
|
||||
if controlnet:
|
||||
@@ -1172,9 +1181,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
||||
)
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
num_train_timesteps = original_config.model.params.timesteps or 1000
|
||||
beta_start = original_config.model.params.linear_start or 0.02
|
||||
beta_end = original_config.model.params.linear_end or 0.085
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
@@ -1216,8 +1225,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
||||
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
|
||||
)
|
||||
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
# Works!
|
||||
import ipdb; ipdb.set_trace()
|
||||
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
|
||||
Reference in New Issue
Block a user