From ba1bfac20b55efe0eee8b8b470110201941acf70 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 29 May 2024 06:30:47 +0530
Subject: [PATCH] [Core] Refactor `IPAdapterPlusImageProjection` a bit (#7994)

* use IPAdapterPlusImageProjectionBlock in IPAdapterPlusImageProjection

* reposition IPAdapterPlusImageProjection

* refactor complete?

* fix heads param retrieval.

* update test dict creation method.
---
 src/diffusers/loaders/unet.py               |  54 +++++--
 src/diffusers/models/embeddings.py          | 142 ++++++++----
 .../unets/test_models_unet_2d_condition.py  |  80 ++++++----
 3 files changed, 153 insertions(+), 123 deletions(-)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 7db7bfeda6..cf67da1cae 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -847,7 +847,12 @@ class UNet2DConditionLoadersMixin:
             embed_dims = state_dict["proj_in.weight"].shape[1]
             output_dims = state_dict["proj_out.weight"].shape[0]
             hidden_dims = state_dict["latents"].shape[2]
-            heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+            attn_key_present = any("attn" in k for k in state_dict)
+            heads = (
+                state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
+                if attn_key_present
+                else state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+            )
 
             with init_context():
                 image_projection = IPAdapterPlusImageProjection(
@@ -860,26 +865,53 @@ class UNet2DConditionLoadersMixin:
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("0.to", "2.to")
-                diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
-                diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
-                diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
-                diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
-                if "norm1" in diffusers_name:
-                    updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
-                elif "norm2" in diffusers_name:
-                    updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
-                elif "to_kv" in diffusers_name:
+                diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0")
+                diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1")
+                diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0")
+                diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1")
+                diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0")
+                diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1")
+                diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0")
+                diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1")
+
+                if "to_kv" in diffusers_name:
+                    parts = diffusers_name.split(".")
+                    parts[2] = "attn"
+                    diffusers_name = ".".join(parts)
                     v_chunk = value.chunk(2, dim=0)
                     updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
                     updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_q" in diffusers_name:
+                    parts = diffusers_name.split(".")
+                    parts[2] = "attn"
+                    diffusers_name = ".".join(parts)
+                    updated_state_dict[diffusers_name] = value
                 elif "to_out" in diffusers_name:
+                    parts = diffusers_name.split(".")
+                    parts[2] = "attn"
+                    diffusers_name = ".".join(parts)
                     updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
                 else:
+                    diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0")
+                    diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj")
+                    diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2")
+
+                    diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0")
+                    diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj")
+                    diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2")
+
+                    diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0")
+                    diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj")
+                    diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2")
+
+                    diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0")
+                    diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj")
+                    diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2")
                     updated_state_dict[diffusers_name] = value
 
         if not low_cpu_mem_usage:
-            image_projection.load_state_dict(updated_state_dict)
+            image_projection.load_state_dict(updated_state_dict, strict=True)
        else:
             load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
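A note on the mapping above: original IP-Adapter Plus resampler keys are translated into the
refactored module's layout. Per-block norms become ln0/ln1, attention sub-keys gain an explicit
"attn" path segment, fused to_kv weights are chunked into separate to_k/to_v, and feed-forward
keys move under ff. A compressed sketch of the same conversion; the helper name
remap_ip_plus_key and the sample keys are illustrative only, not part of this patch (the
feed-forward renames, "0.1.0" -> "0.ff.0" and so on, follow the same pattern and are omitted):

    # Illustrative sketch of the key conversion performed by the loader above.
    def remap_ip_plus_key(key: str) -> list:
        name = key.replace("0.to", "2.to")
        for i in range(4):  # one ln0/ln1 pair per resampler block
            name = name.replace(f"{i}.0.norm1", f"{i}.ln0").replace(f"{i}.0.norm2", f"{i}.ln1")
        if any(t in name for t in ("to_q", "to_kv", "to_out")):
            parts = name.split(".")
            parts[2] = "attn"  # e.g. layers.0.2.to_q -> layers.0.attn.to_q
            name = ".".join(parts)
        if "to_kv" in name:  # the fused key/value weight becomes two keys (chunked in the loader)
            return [name.replace("to_kv", "to_k"), name.replace("to_kv", "to_v")]
        if "to_out" in name:
            return [name.replace("to_out", "to_out.0")]
        return [name]

    print(remap_ip_plus_key("layers.0.0.norm1.weight"))  # ['layers.0.ln0.weight']
    print(remap_ip_plus_key("layers.0.0.to_kv.weight"))
    # ['layers.0.attn.to_k.weight', 'layers.0.attn.to_v.weight']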
diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2") + + diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0") + diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2") + + diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0") + diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2") updated_state_dict[diffusers_name] = value if not low_cpu_mem_usage: - image_projection.load_state_dict(updated_state_dict) + image_projection.load_state_dict(updated_state_dict, strict=True) else: load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d13f8a06cf..d2940e861c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -806,89 +806,6 @@ class PixArtAlphaTextProjection(nn.Module): return hidden_states -class IPAdapterPlusImageProjection(nn.Module): - """Resampler of IP-Adapter Plus. - - Args: - embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, - that is the same - number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. - hidden_dims (int): - The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults - to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. - Defaults to 16. num_queries (int): - The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio - of feedforward network hidden - layer channels. Defaults to 4. - """ - - def __init__( - self, - embed_dims: int = 768, - output_dims: int = 1024, - hidden_dims: int = 1280, - depth: int = 4, - dim_head: int = 64, - heads: int = 16, - num_queries: int = 8, - ffn_ratio: float = 4, - ) -> None: - super().__init__() - from .attention import FeedForward # Lazy import to avoid circular import - - self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) - - self.proj_in = nn.Linear(embed_dims, hidden_dims) - - self.proj_out = nn.Linear(hidden_dims, output_dims) - self.norm_out = nn.LayerNorm(output_dims) - - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - nn.LayerNorm(hidden_dims), - nn.LayerNorm(hidden_dims), - Attention( - query_dim=hidden_dims, - dim_head=dim_head, - heads=heads, - out_bias=False, - ), - nn.Sequential( - nn.LayerNorm(hidden_dims), - FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), - ), - ] - ) - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Args: - x (torch.Tensor): Input Tensor. - Returns: - torch.Tensor: Output Tensor. 
- """ - latents = self.latents.repeat(x.size(0), 1, 1) - - x = self.proj_in(x) - - for ln0, ln1, attn, ff in self.layers: - residual = latents - - encoder_hidden_states = ln0(x) - latents = ln1(latents) - encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) - latents = attn(latents, encoder_hidden_states) + residual - latents = ff(latents) + latents - - latents = self.proj_out(latents) - return self.norm_out(latents) - - class IPAdapterPlusImageProjectionBlock(nn.Module): def __init__( self, @@ -922,6 +839,65 @@ class IPAdapterPlusImageProjectionBlock(nn.Module): return latents +class IPAdapterPlusImageProjection(nn.Module): + """Resampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_queries (int): + The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio + of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + + self.proj_in = nn.Linear(embed_dims, hidden_dims) + + self.proj_out = nn.Linear(hidden_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x (torch.Tensor): Input Tensor. + Returns: + torch.Tensor: Output Tensor. + """ + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + return self.norm_out(latents) + + class IPAdapterFaceIDPlusImageProjection(nn.Module): """FacePerceiverResampler of IP-Adapter Plus. diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 33aa6a1037..ad33df964d 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -146,42 +146,64 @@ def create_ip_adapter_plus_state_dict(model): ) ip_image_projection_state_dict = OrderedDict() + keys = [k for k in image_projection.state_dict() if "layers." 
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 33aa6a1037..ad33df964d 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -146,42 +146,64 @@ def create_ip_adapter_plus_state_dict(model):
     )
 
     ip_image_projection_state_dict = OrderedDict()
+    keys = [k for k in image_projection.state_dict() if "layers." in k]
+    print(keys)
     for k, v in image_projection.state_dict().items():
         if "2.to" in k:
             k = k.replace("2.to", "0.to")
-        elif "3.0.weight" in k:
-            k = k.replace("3.0.weight", "1.0.weight")
-        elif "3.0.bias" in k:
-            k = k.replace("3.0.bias", "1.0.bias")
-        elif "3.0.weight" in k:
-            k = k.replace("3.0.weight", "1.0.weight")
-        elif "3.1.net.0.proj.weight" in k:
-            k = k.replace("3.1.net.0.proj.weight", "1.1.weight")
-        elif "3.net.2.weight" in k:
-            k = k.replace("3.net.2.weight", "1.3.weight")
-        elif "layers.0.0" in k:
-            k = k.replace("layers.0.0", "layers.0.0.norm1")
-        elif "layers.0.1" in k:
-            k = k.replace("layers.0.1", "layers.0.0.norm2")
-        elif "layers.1.0" in k:
-            k = k.replace("layers.1.0", "layers.1.0.norm1")
-        elif "layers.1.1" in k:
-            k = k.replace("layers.1.1", "layers.1.0.norm2")
-        elif "layers.2.0" in k:
-            k = k.replace("layers.2.0", "layers.2.0.norm1")
-        elif "layers.2.1" in k:
-            k = k.replace("layers.2.1", "layers.2.0.norm2")
+        elif "layers.0.ln0" in k:
+            k = k.replace("layers.0.ln0", "layers.0.0.norm1")
+        elif "layers.0.ln1" in k:
+            k = k.replace("layers.0.ln1", "layers.0.0.norm2")
+        elif "layers.1.ln0" in k:
+            k = k.replace("layers.1.ln0", "layers.1.0.norm1")
+        elif "layers.1.ln1" in k:
+            k = k.replace("layers.1.ln1", "layers.1.0.norm2")
+        elif "layers.2.ln0" in k:
+            k = k.replace("layers.2.ln0", "layers.2.0.norm1")
+        elif "layers.2.ln1" in k:
+            k = k.replace("layers.2.ln1", "layers.2.0.norm2")
+        elif "layers.3.ln0" in k:
+            k = k.replace("layers.3.ln0", "layers.3.0.norm1")
+        elif "layers.3.ln1" in k:
+            k = k.replace("layers.3.ln1", "layers.3.0.norm2")
+        elif "to_q" in k:
+            parts = k.split(".")
+            parts[2] = "attn"
+            k = ".".join(parts)
+        elif "to_out.0" in k:
+            parts = k.split(".")
+            parts[2] = "attn"
+            k = ".".join(parts)
+            k = k.replace("to_out.0", "to_out")
+        else:
+            k = k.replace("0.ff.0", "0.1.0")
+            k = k.replace("0.ff.1.net.0.proj", "0.1.1")
+            k = k.replace("0.ff.1.net.2", "0.1.3")
 
-        if "norm_cross" in k:
-            ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
-        elif "layer_norm" in k:
-            ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
-        elif "to_k" in k:
+            k = k.replace("1.ff.0", "1.1.0")
+            k = k.replace("1.ff.1.net.0.proj", "1.1.1")
+            k = k.replace("1.ff.1.net.2", "1.1.3")
+
+            k = k.replace("2.ff.0", "2.1.0")
+            k = k.replace("2.ff.1.net.0.proj", "2.1.1")
+            k = k.replace("2.ff.1.net.2", "2.1.3")
+
+            k = k.replace("3.ff.0", "3.1.0")
+            k = k.replace("3.ff.1.net.0.proj", "3.1.1")
+            k = k.replace("3.ff.1.net.2", "3.1.3")
+
+        # if "norm_cross" in k:
+        #     ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
+        # elif "layer_norm" in k:
+        #     ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
+        if "to_k" in k:
+            parts = k.split(".")
+            parts[2] = "attn"
+            k = ".".join(parts)
             ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0)
         elif "to_v" in k:
             continue
-        elif "to_out.0" in k:
-            ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v
         else:
             ip_image_projection_state_dict[k] = v
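With the refactor in place, per-block parameters live under ln0/ln1/attn/ff, which is exactly
what the loader conversion and this test helper key off of. A small sanity check of those names
and of the heads recovery used by the loader; the assertions are illustrative, not part of the
test suite:

    from diffusers.models.embeddings import IPAdapterPlusImageProjection

    proj = IPAdapterPlusImageProjection(depth=4)
    state = proj.state_dict()

    # e.g. layers.0.ln0.weight, layers.0.attn.to_q.weight, layers.0.ff.1.net.0.proj.weight
    for i in range(4):
        assert f"layers.{i}.ln0.weight" in state
        assert f"layers.{i}.attn.to_q.weight" in state
        assert f"layers.{i}.ff.1.net.0.proj.weight" in state

    # heads can be recovered the same way the loader does it
    heads = state["layers.0.attn.to_q.weight"].shape[0] // 64
    print(heads)  # 16 with the defaults (dim_head=64, heads=16)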