diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index eee7e6023e..a57a469caa 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -38,6 +38,7 @@ def get_down_block(
     add_downsample,
     resnet_eps,
     resnet_act_fn,
+    num_transformer_blocks=1,
     num_attention_heads=None,
     resnet_groups=None,
     cross_attention_dim=None,
@@ -106,6 +107,7 @@ def get_down_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
         return CrossAttnDownBlock2D(
             num_layers=num_layers,
+            num_transformer_blocks=num_transformer_blocks,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
@@ -227,6 +229,7 @@ def get_up_block(
     add_upsample,
     resnet_eps,
     resnet_act_fn,
+    num_transformer_blocks=1,
     num_attention_heads=None,
     resnet_groups=None,
     cross_attention_dim=None,
@@ -281,6 +284,7 @@ def get_up_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
         return CrossAttnUpBlock2D(
             num_layers=num_layers,
+            num_transformer_blocks=num_transformer_blocks,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,
@@ -506,6 +510,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        num_transformer_blocks: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -548,7 +553,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         num_attention_heads,
                         in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=1,
+                        num_layers=num_transformer_blocks,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -829,6 +834,7 @@ class CrossAttnDownBlock2D(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        num_transformer_blocks: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -873,7 +879,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=num_transformer_blocks,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1939,6 +1945,7 @@ class CrossAttnUpBlock2D(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        num_transformer_blocks: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1984,7 +1991,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=num_transformer_blocks,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 7bca5c336c..0cc9618d91 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -96,6 +96,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
+        num_transformer_blocks (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
         encoder_hid_dim (`int`, *optional*, defaults to None):
             If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
             dimension to `cross_attention_dim`.
@@ -168,6 +170,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        num_transformer_blocks: Union[int, Tuple[int]] = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -381,6 +384,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if isinstance(layers_per_block, int):
             layers_per_block = [layers_per_block] * len(down_block_types)
 
+        if isinstance(num_transformer_blocks, int):
+            num_transformer_blocks = [num_transformer_blocks] * len(down_block_types)
+
         if class_embeddings_concat:
             # The time embeddings are concatenated with the class embeddings. The dimension of the
             # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
@@ -399,6 +405,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             down_block = get_down_block(
                 down_block_type,
                 num_layers=layers_per_block[i],
+                num_transformer_blocks=num_transformer_blocks[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=blocks_time_embed_dim,
@@ -424,6 +431,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         # mid
         if mid_block_type == "UNetMidBlock2DCrossAttn":
             self.mid_block = UNetMidBlock2DCrossAttn(
+                num_transformer_blocks=num_transformer_blocks[-1],
                 in_channels=block_out_channels[-1],
                 temb_channels=blocks_time_embed_dim,
                 resnet_eps=norm_eps,
@@ -465,6 +473,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         reversed_num_attention_heads = list(reversed(num_attention_heads))
         reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_num_transformer_blocks = list(reversed(num_transformer_blocks))
         only_cross_attention = list(reversed(only_cross_attention))
 
         output_channel = reversed_block_out_channels[0]
@@ -485,6 +494,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             up_block = get_up_block(
                 up_block_type,
                 num_layers=reversed_layers_per_block[i] + 1,
+                num_transformer_blocks=reversed_num_transformer_blocks[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 3b3724f0d0..94cc1b36d2 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -233,7 +233,10 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     if controlnet:
         unet_params = original_config.model.params.control_stage_config.params
     else:
-        unet_params = original_config.model.params.unet_config.params
+        if original_config.model.params.unet_config is not None:
+            unet_params = original_config.model.params.unet_config.params
+        else:
+            unet_params = original_config.model.params.network_config.params
 
     vae_params = original_config.model.params.first_stage_config.params.ddconfig
 
@@ -253,6 +256,11 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         up_block_types.append(block_type)
         resolution //= 2
 
+    if unet_params.transformer_depth is not None:
+        num_transformer_blocks = unet_params.transformer_depth if isinstance(unet_params.transformer_depth, int) else list(unet_params.transformer_depth)
+    else:
+        num_transformer_blocks = 1
+
     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
 
     head_dim = unet_params.num_heads if "num_heads" in unet_params else None
@@ -262,7 +270,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     if use_linear_projection:
         # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
-            head_dim = [5, 10, 20, 20]
+            head_dim = [5 * c for c in list(unet_params.channel_mult)]
 
     class_embed_type = None
     projection_class_embeddings_input_dim = None
@@ -286,6 +294,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         "use_linear_projection": use_linear_projection,
         "class_embed_type": class_embed_type,
         "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+        "num_transformer_blocks": num_transformer_blocks,
     }
 
     if controlnet:
@@ -1172,9 +1181,9 @@ def download_from_original_stable_diffusion_ckpt(
             checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
         )
 
-    num_train_timesteps = original_config.model.params.timesteps
-    beta_start = original_config.model.params.linear_start
-    beta_end = original_config.model.params.linear_end
+    num_train_timesteps = original_config.model.params.timesteps or 1000
+    beta_start = original_config.model.params.linear_start or 0.02
+    beta_end = original_config.model.params.linear_end or 0.085
 
     scheduler = DDIMScheduler(
         beta_end=beta_end,
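A minimal usage sketch of the new argument for reviewers (the block types and channel widths below are illustrative, not from a released config; `num_transformer_blocks` only exists with this patch applied):

```python
from diffusers import UNet2DConditionModel

# One entry per down block; up blocks consume the same sequence in reverse,
# and the mid block reuses the last entry (num_transformer_blocks[-1]).
unet = UNet2DConditionModel(
    block_out_channels=(320, 640, 1280),
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=1024,
    num_transformer_blocks=(1, 2, 4),  # 1, 2, then 4 BasicTransformerBlocks per Transformer2DModel
)
```

An `int` is broadcast to one entry per down block, so existing configs that never set the argument keep exactly one `BasicTransformerBlock` per attention layer, as before.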
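One worked check on the `head_dim` change: the base factor 5 is `model_channels // num_head_channels` (320 // 64) for the SD 2.x family, so with the SD 2.x `channel_mult` of `[1, 2, 4, 4]` the comprehension reproduces the previously hardcoded list, while shorter `channel_mult` sequences now get a list of matching length:

```python
>>> [5 * c for c in [1, 2, 4, 4]]  # SD 2.x channel_mult -> old hardcoded value
[5, 10, 20, 20]
>>> [5 * c for c in [1, 2, 4]]     # a three-level channel_mult
[5, 10, 20]
```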