From 91fd181245d7b287c735f2f479e4c498d5458462 Mon Sep 17 00:00:00 2001
From: Aryan V S
Date: Wed, 11 Oct 2023 16:34:59 +0530
Subject: [PATCH] Improve typehints and docs in `diffusers/models` (#5312)

* improvement: add missing typehints and docs to diffusers/models/attention.py

* chore: convert doc strings to raw python strings add missing typehints

* improvement: add missing typehints and docs to diffusers/models/adapter.py

* improvement: add missing typehints and docs to diffusers/models/lora.py

* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py

Co-authored-by: Sayak Paul

* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py

Co-authored-by: Sayak Paul

* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py

Co-authored-by: Sayak Paul

* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py

Co-authored-by: Sayak Paul

* Update src/diffusers/models/lora.py

---------

Co-authored-by: Sayak Paul
Co-authored-by: Patrick von Platen
---
 src/diffusers/models/adapter.py   | 83 +++++++++++++++++++++++----
 src/diffusers/models/attention.py | 95 +++++++++++++++++++++++--------
 src/diffusers/models/lora.py      | 75 +++++++++++++++++++++---
 3 files changed, 211 insertions(+), 42 deletions(-)

diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py
index bf6803c565..64d64d07bf 100644
--- a/src/diffusers/models/adapter.py
+++ b/src/diffusers/models/adapter.py
@@ -231,7 +231,11 @@ class T2IAdapter(ModelMixin, ConfigMixin):
             The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
             also determine the number of downsample blocks in the Adapter.
         num_res_blocks (`int`, *optional*, defaults to 2):
-            Number of ResNet blocks in each downsample block
+            Number of ResNet blocks in each downsample block.
+        downscale_factor (`int`, *optional*, defaults to 8):
+            A factor that determines the total downscale factor of the Adapter.
+        adapter_type (`str`, *optional*, defaults to `full_adapter`):
+            The type of Adapter to use. Choose one of `full_adapter`, `full_adapter_xl`, or `light_adapter`.
     """
 
     @register_to_config
@@ -275,6 +279,10 @@ class T2IAdapter(ModelMixin, ConfigMixin):
 
 
 class FullAdapter(nn.Module):
+    r"""
+    See [`T2IAdapter`] for more information.
+    """
+
     def __init__(
         self,
         in_channels: int = 3,
@@ -321,6 +329,10 @@ class FullAdapter(nn.Module):
 
 
 class FullAdapterXL(nn.Module):
+    r"""
+    See [`T2IAdapter`] for more information.
+    """
+
     def __init__(
         self,
         in_channels: int = 3,
@@ -367,7 +379,22 @@ class FullAdapterXL(nn.Module):
 
 
 class AdapterBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
+    r"""
+    An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
+    `FullAdapterXL` models.
+
+    Parameters:
+        in_channels (`int`):
+            Number of channels of AdapterBlock's input.
+        out_channels (`int`):
+            Number of channels of AdapterBlock's output.
+        num_res_blocks (`int`):
+            Number of ResNet blocks in the AdapterBlock.
+        down (`bool`, *optional*, defaults to `False`):
+            Whether to perform downsampling on AdapterBlock's input.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
         super().__init__()
 
         self.downsample = None
@@ -382,7 +409,7 @@ class AdapterBlock(nn.Module):
             *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         This method takes tensor x as input and performs operations downsampling and convolutional layers if the
         self.downsample and self.in_conv properties of AdapterBlock model are specified. Then it applies a series of
@@ -400,13 +427,21 @@ class AdapterBlock(nn.Module):
 
 
 class AdapterResnetBlock(nn.Module):
-    def __init__(self, channels):
+    r"""
+    An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
+
+    Parameters:
+        channels (`int`):
+            Number of channels of AdapterResnetBlock's input and output.
+    """
+
+    def __init__(self, channels: int):
         super().__init__()
         self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
         self.act = nn.ReLU()
         self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         This method takes input tensor x and applies a convolutional layer, ReLU activation, and another
         convolutional layer on the input tensor. It returns addition with the input tensor.
@@ -423,6 +458,10 @@ class AdapterResnetBlock(nn.Module):
 
 
 class LightAdapter(nn.Module):
+    r"""
+    See [`T2IAdapter`] for more information.
+    """
+
     def __init__(
         self,
         in_channels: int = 3,
@@ -449,7 +488,7 @@ class LightAdapter(nn.Module):
 
         self.total_downscale_factor = downscale_factor * (2 ** len(channels))
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         r"""
         This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
         feature tensor corresponds to a different level of processing within the LightAdapter.
@@ -466,7 +505,22 @@ class LightAdapter(nn.Module):
 
 
 class LightAdapterBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
+    r"""
+    A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
+    `LightAdapter` model.
+
+    Parameters:
+        in_channels (`int`):
+            Number of channels of LightAdapterBlock's input.
+        out_channels (`int`):
+            Number of channels of LightAdapterBlock's output.
+        num_res_blocks (`int`):
+            Number of LightAdapterResnetBlocks in the LightAdapterBlock.
+        down (`bool`, *optional*, defaults to `False`):
+            Whether to perform downsampling on LightAdapterBlock's input.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
         super().__init__()
 
         mid_channels = out_channels // 4
@@ -478,7 +532,7 @@ class LightAdapterBlock(nn.Module):
         self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
         self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         This method takes tensor x as input and performs downsampling if required. Then it applies in convolution
         layer, a sequence of residual blocks, and out convolutional layer.
@@ -494,13 +548,22 @@ class LightAdapterBlock(nn.Module):
 
 
 class LightAdapterResnetBlock(nn.Module):
-    def __init__(self, channels):
+    r"""
+    A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
+    architecture than `AdapterResnetBlock`.
+
+    Parameters:
+        channels (`int`):
+            Number of channels of LightAdapterResnetBlock's input and output.
+    """
+
+    def __init__(self, channels: int):
         super().__init__()
         self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
         self.act = nn.ReLU()
         self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
         another convolutional layer and adds it to input tensor.
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 892d44a031..6f5d1da6c6 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +26,17 @@ from .lora import LoRACompatibleLinear
 
 @maybe_allow_in_graph
 class GatedSelfAttentionDense(nn.Module):
-    def __init__(self, query_dim, context_dim, n_heads, d_head):
+    r"""
+    A gated self-attention dense layer that combines visual features and object features.
+
+    Parameters:
+        query_dim (`int`): The number of channels in the query.
+        context_dim (`int`): The number of channels in the context.
+        n_heads (`int`): The number of heads to use for attention.
+        d_head (`int`): The number of channels in each head.
+    """
+
+    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
         super().__init__()
 
         # we need a linear projection since we need cat visual feature and obj feature
@@ -43,7 +53,7 @@ class GatedSelfAttentionDense(nn.Module):
 
         self.enabled = True
 
-    def forward(self, x, objs):
+    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
         if not self.enabled:
             return x
 
@@ -67,15 +77,25 @@ class BasicTransformerBlock(nn.Module):
         attention_head_dim (`int`): The number of channels in each head.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
-        only_cross_attention (`bool`, *optional*):
-            Whether to use only cross-attention layers. In this case two cross attention layers are used.
-        double_self_attention (`bool`, *optional*):
-            Whether to use two self-attention layers. In this case no cross attention layers are used.
         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
         attention_bias (:
           obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool`, *optional*, defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, *optional*, defaults to `"default"`):
+            The type of attention to use. Can be `"default"`, `"gated"`, or `"gated-text-image"`.
     """
 
     def __init__(
@@ -175,7 +195,7 @@ class BasicTransformerBlock(nn.Module):
         timestep: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
         class_labels: Optional[torch.LongTensor] = None,
-    ):
+    ) -> torch.FloatTensor:
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
         if self.use_ada_layer_norm:
@@ -301,7 +321,7 @@ class FeedForward(nn.Module):
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
 
-    def forward(self, hidden_states, scale: float = 1.0):
+    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
         for module in self.net:
             if isinstance(module, (LoRACompatibleLinear, GEGLU)):
                 hidden_states = module(hidden_states, scale)
@@ -313,6 +333,11 @@ class FeedForward(nn.Module):
 class GELU(nn.Module):
     r"""
     GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
     """
 
     def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
@@ -320,7 +345,7 @@ class GELU(nn.Module):
         self.proj = nn.Linear(dim_in, dim_out)
         self.approximate = approximate
 
-    def gelu(self, gate):
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
             return F.gelu(gate, approximate=self.approximate)
         # mps: gelu is not implemented for float16
@@ -345,7 +370,7 @@ class GEGLU(nn.Module):
         super().__init__()
         self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
 
-    def gelu(self, gate):
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
             return F.gelu(gate)
         # mps: gelu is not implemented for float16
@@ -357,34 +382,41 @@ class GEGLU(nn.Module):
 
 
 class ApproximateGELU(nn.Module):
-    """
-    The approximate form of Gaussian Error Linear Unit (GELU)
+    r"""
+    The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
+    https://arxiv.org/abs/1606.08415.
 
-    For more details, see section 2: https://arxiv.org/abs/1606.08415
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
     """
 
     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
 
 
 class AdaLayerNorm(nn.Module):
-    """
+    r"""
     Norm layer modified to incorporate timestep embeddings.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the dictionary of embeddings.
     """
 
-    def __init__(self, embedding_dim, num_embeddings):
+    def __init__(self, embedding_dim: int, num_embeddings: int):
         super().__init__()
         self.emb = nn.Embedding(num_embeddings, embedding_dim)
         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
         self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
 
-    def forward(self, x, timestep):
+    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
         emb = self.linear(self.silu(self.emb(timestep)))
         scale, shift = torch.chunk(emb, 2)
         x = self.norm(x) * (1 + scale) + shift
@@ -392,11 +424,15 @@ class AdaLayerNorm(nn.Module):
 
 
 class AdaLayerNormZero(nn.Module):
-    """
+    r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the dictionary of embeddings.
     """
 
-    def __init__(self, embedding_dim, num_embeddings):
+    def __init__(self, embedding_dim: int, num_embeddings: int):
         super().__init__()
 
         self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
@@ -405,7 +441,13 @@ class AdaLayerNormZero(nn.Module):
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
         self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
 
-    def forward(self, x, timestep, class_labels, hidden_dtype=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: torch.Tensor,
+        class_labels: torch.LongTensor,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
@@ -413,8 +455,15 @@ class AdaGroupNorm(nn.Module):
 
 
 class AdaGroupNorm(nn.Module):
-    """
+    r"""
     GroupNorm layer modified to incorporate timestep embeddings.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the dictionary of embeddings.
+        num_groups (`int`): The number of groups to separate the channels into.
+        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
+        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
     """
 
     def __init__(
@@ -431,7 +480,7 @@ class AdaGroupNorm(nn.Module):
 
         self.linear = nn.Linear(embedding_dim, out_dim * 2)
 
-    def forward(self, x, emb):
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         if self.act:
             emb = self.act(emb)
         emb = self.linear(emb)
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index aec7200afd..a143c17458 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -40,7 +40,35 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
 
 
 class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
+    r"""
+    A linear layer that is used with LoRA.
+
+    Parameters:
+        in_features (`int`):
+            Number of input features.
+        out_features (`int`):
+            Number of output features.
+        rank (`int`, *optional*, defaults to 4):
+            The rank of the LoRA layer.
+        network_alpha (`float`, *optional*, defaults to `None`):
+            The value of the network alpha used for stable learning and preventing underflow. This value has the same
+            meaning as the `--network_alpha` option in the kohya-ss trainer script. See
+            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        device (`torch.device`, *optional*, defaults to `None`):
+            The device to use for the layer's weights.
+        dtype (`torch.dtype`, *optional*, defaults to `None`):
+            The dtype to use for the layer's weights.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        rank: int = 4,
+        network_alpha: Optional[float] = None,
+        device: Optional[Union[torch.device, str]] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
         super().__init__()
 
         self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
@@ -55,7 +83,7 @@ class LoRALinearLayer(nn.Module):
         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_dtype = hidden_states.dtype
         dtype = self.down.weight.dtype
 
@@ -69,8 +97,37 @@ class LoRALinearLayer(nn.Module):
 
 
 class LoRAConv2dLayer(nn.Module):
+    r"""
+    A convolutional layer that is used with LoRA.
+
+    Parameters:
+        in_features (`int`):
+            Number of input features.
+        out_features (`int`):
+            Number of output features.
+        rank (`int`, *optional*, defaults to 4):
+            The rank of the LoRA layer.
+        kernel_size (`int` or `tuple` of two `int`, *optional*, defaults to 1):
+            The kernel size of the convolution.
+        stride (`int` or `tuple` of two `int`, *optional*, defaults to 1):
+            The stride of the convolution.
+        padding (`int` or `tuple` of two `int` or `str`, *optional*, defaults to 0):
+            The padding of the convolution.
+        network_alpha (`float`, *optional*, defaults to `None`):
+            The value of the network alpha used for stable learning and preventing underflow. This value has the same
+            meaning as the `--network_alpha` option in the kohya-ss trainer script. See
+            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+    """
+
     def __init__(
-        self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
+        self,
+        in_features: int,
+        out_features: int,
+        rank: int = 4,
+        kernel_size: Union[int, Tuple[int, int]] = (1, 1),
+        stride: Union[int, Tuple[int, int]] = (1, 1),
+        padding: Union[int, Tuple[int, int], str] = 0,
+        network_alpha: Optional[float] = None,
     ):
         super().__init__()
 
@@ -87,7 +144,7 @@ class LoRAConv2dLayer(nn.Module):
         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_dtype = hidden_states.dtype
         dtype = self.down.weight.dtype
 
@@ -112,7 +169,7 @@ class LoRACompatibleConv(nn.Conv2d):
     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+    def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
         if self.lora_layer is None:
             return
 
@@ -164,7 +221,7 @@ class LoRACompatibleConv(nn.Conv2d):
         self.w_up = None
         self.w_down = None
 
-    def forward(self, hidden_states, scale: float = 1.0):
+    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
         if self.lora_layer is None:
             # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
             # see: https://github.com/huggingface/diffusers/pull/4315
@@ -190,7 +247,7 @@ class LoRACompatibleLinear(nn.Linear):
     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+    def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
         if self.lora_layer is None:
             return
 
@@ -238,7 +295,7 @@ class LoRACompatibleLinear(nn.Linear):
         self.w_up = None
         self.w_down = None
 
-    def forward(self, hidden_states, scale: float = 1.0):
+    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
         if self.lora_layer is None:
             out = super().forward(hidden_states)
             return out
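A quick usage sketch (reviewer note, not part of the diff) for the `T2IAdapter` arguments documented in adapter.py above. The channel widths, the 512x512 input and the printed shapes are illustrative values only; the adapter variants are the ones named in the new `adapter_type` docstring.

    import torch

    from diffusers.models.adapter import T2IAdapter

    # Configure the "full_adapter" variant described in the T2IAdapter docstring.
    adapter = T2IAdapter(
        in_channels=3,
        channels=[320, 640, 1280, 1280],  # one entry per downsample block
        num_res_blocks=2,
        downscale_factor=8,
        adapter_type="full_adapter",
    )

    image = torch.randn(1, 3, 512, 512)
    features = adapter(image)  # List[torch.Tensor], one feature map per block
    for feature in features:
        print(feature.shape)  # (1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)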
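Similarly, a minimal sketch of calling `BasicTransformerBlock` with the options covered by the expanded docstring (self-attention plus cross-attention against a 768-dim context). All tensor sizes here are made up for illustration.

    import torch

    from diffusers.models.attention import BasicTransformerBlock

    block = BasicTransformerBlock(
        dim=320,
        num_attention_heads=8,
        attention_head_dim=40,
        cross_attention_dim=768,  # width of encoder_hidden_states used for cross-attention
    )

    hidden_states = torch.randn(2, 64, 320)  # (batch, tokens, dim)
    encoder_hidden_states = torch.randn(2, 77, 768)  # e.g. text-encoder output
    out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
    print(out.shape)  # torch.Size([2, 64, 320])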
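Finally, a hedged sketch for the lora.py classes: attaching a `LoRALinearLayer` to a `LoRACompatibleLinear`, scaling its contribution at call time, and optionally fusing it into the base weight. The feature sizes and the `rank`, `network_alpha` and `scale` values are arbitrary examples.

    import torch

    from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

    linear = LoRACompatibleLinear(768, 320)
    linear.set_lora_layer(LoRALinearLayer(in_features=768, out_features=320, rank=4, network_alpha=4))

    hidden_states = torch.randn(2, 77, 768)
    out = linear(hidden_states, scale=0.8)  # base projection + 0.8 * LoRA correction
    print(out.shape)  # torch.Size([2, 77, 320])

    # _fuse_lora folds the scaled LoRA weights into the base weight matrix, so no
    # extra matmuls are needed at inference time.
    linear._fuse_lora(lora_scale=0.8)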