diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 04c978403f..46da899096 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -1,7 +1,15 @@ from torch import nn -def get_activation(act_fn): +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ if act_fn in ["swish", "silu"]: return nn.SiLU() elif act_fn == "mish": diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ac66e2271c..3972b438b0 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -14,7 +14,7 @@ # limitations under the License. from functools import partial -from typing import Optional +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -38,9 +38,18 @@ class Upsample1D(nn.Module): option to use a convolution transpose. out_channels (`int`, optional): number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 1D layer. """ - def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + ): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -54,7 +63,7 @@ class Upsample1D(nn.Module): elif use_conv: self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: assert inputs.shape[1] == self.channels if self.use_conv_transpose: return self.conv(inputs) @@ -79,9 +88,18 @@ class Downsample1D(nn.Module): number of output channels. Defaults to `channels`. padding (`int`, default `1`): padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 1D layer. """ - def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + ): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -96,7 +114,7 @@ class Downsample1D(nn.Module): assert self.channels == self.out_channels self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: assert inputs.shape[1] == self.channels return self.conv(inputs) @@ -113,9 +131,18 @@ class Upsample2D(nn.Module): option to use a convolution transpose. out_channels (`int`, optional): number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 2D layer. """ - def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + ): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -135,7 +162,7 @@ class Upsample2D(nn.Module): else: self.Conv2d_0 = conv - def forward(self, hidden_states, output_size=None, scale: float = 1.0): + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: @@ -191,9 +218,18 @@ class Downsample2D(nn.Module): number of output channels. Defaults to `channels`. padding (`int`, default `1`): padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 2D layer. """ - def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + ): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -246,7 +282,13 @@ class FirUpsample2D(nn.Module): kernel for the FIR filter. """ - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + def __init__( + self, + channels: int = None, + out_channels: Optional[int] = None, + use_conv: bool = False, + fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), + ): super().__init__() out_channels = out_channels if out_channels else channels if use_conv: @@ -255,7 +297,14 @@ class FirUpsample2D(nn.Module): self.fir_kernel = fir_kernel self.out_channels = out_channels - def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + def _upsample_2d( + self, + hidden_states: torch.Tensor, + weight: Optional[torch.Tensor] = None, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, + ) -> torch.Tensor: """Fused `upsample_2d()` followed by `Conv2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more @@ -335,7 +384,7 @@ class FirUpsample2D(nn.Module): return output - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.use_conv: height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) @@ -359,7 +408,13 @@ class FirDownsample2D(nn.Module): kernel for the FIR filter. """ - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + def __init__( + self, + channels: int = None, + out_channels: Optional[int] = None, + use_conv: bool = False, + fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), + ): super().__init__() out_channels = out_channels if out_channels else channels if use_conv: @@ -368,7 +423,14 @@ class FirDownsample2D(nn.Module): self.use_conv = use_conv self.out_channels = out_channels - def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + def _downsample_2d( + self, + hidden_states: torch.Tensor, + weight: Optional[torch.Tensor] = None, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, + ) -> torch.Tensor: """Fused `Conv2d()` followed by `downsample_2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of @@ -422,7 +484,7 @@ class FirDownsample2D(nn.Module): return output - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.use_conv: downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) @@ -434,14 +496,20 @@ class FirDownsample2D(nn.Module): # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead class KDownsample2D(nn.Module): - def __init__(self, pad_mode="reflect"): + r"""A 2D K-downsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. + """ + + def __init__(self, pad_mode: str = "reflect"): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) self.pad = kernel_1d.shape[1] // 2 - 1 self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) indices = torch.arange(inputs.shape[1], device=inputs.device) @@ -451,14 +519,20 @@ class KDownsample2D(nn.Module): class KUpsample2D(nn.Module): - def __init__(self, pad_mode="reflect"): + r"""A 2D K-upsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. + """ + + def __init__(self, pad_mode: str = "reflect"): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 self.pad = kernel_1d.shape[1] // 2 - 1 self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) indices = torch.arange(inputs.shape[1], device=inputs.device) @@ -501,23 +575,23 @@ class ResnetBlock2D(nn.Module): def __init__( self, *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - skip_time_act=False, - time_embedding_norm="default", # default, scale_shift, ada_group, spatial - kernel=None, - output_scale_factor=1.0, - use_in_shortcut=None, - up=False, - down=False, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial + kernel: Optional[torch.FloatTensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, conv_shortcut_bias: bool = True, conv_2d_out_channels: Optional[int] = None, ): @@ -667,7 +741,7 @@ class ResnetBlock2D(nn.Module): # unet_rl.py -def rearrange_dims(tensor): +def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor: if len(tensor.shape) == 2: return tensor[:, :, None] if len(tensor.shape) == 3: @@ -681,16 +755,24 @@ def rearrange_dims(tensor): class Conv1dBlock(nn.Module): """ Conv1d --> GroupNorm --> Mish + + Parameters: + inp_channels (`int`): Number of input channels. + out_channels (`int`): Number of output channels. + kernel_size (`int` or `tuple`): Size of the convolving kernel. + n_groups (`int`, default `8`): Number of groups to separate the channels into. """ - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + def __init__( + self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8 + ): super().__init__() self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) self.group_norm = nn.GroupNorm(n_groups, out_channels) self.mish = nn.Mish() - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: intermediate_repr = self.conv1d(inputs) intermediate_repr = rearrange_dims(intermediate_repr) intermediate_repr = self.group_norm(intermediate_repr) @@ -701,7 +783,19 @@ class Conv1dBlock(nn.Module): # unet_rl.py class ResidualTemporalBlock1D(nn.Module): - def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): + """ + Residual 1D block with temporal convolutions. + + Parameters: + inp_channels (`int`): Number of input channels. + out_channels (`int`): Number of output channels. + embed_dim (`int`): Embedding dimension. + kernel_size (`int` or `tuple`): Size of the convolving kernel. + """ + + def __init__( + self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5 + ): super().__init__() self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) @@ -713,7 +807,7 @@ class ResidualTemporalBlock1D(nn.Module): nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() ) - def forward(self, inputs, t): + def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: """ Args: inputs : [ batch_size x inp_channels x horizon ] @@ -729,7 +823,9 @@ class ResidualTemporalBlock1D(nn.Module): return out + self.residual_conv(inputs) -def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): +def upsample_2d( + hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 +) -> torch.Tensor: r"""Upsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified @@ -766,7 +862,9 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): return output -def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): +def downsample_2d( + hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 +) -> torch.Tensor: r"""Downsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the @@ -801,7 +899,9 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): return output -def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): +def upfirdn2d_native( + tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0) +) -> torch.Tensor: up_x = up_y = up down_x = down_y = down pad_x0 = pad_y0 = pad[0] @@ -849,9 +949,14 @@ class TemporalConvLayer(nn.Module): """ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + + Parameters: + in_dim (`int`): Number of input channels. + out_dim (`int`): Number of output channels. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. """ - def __init__(self, in_dim, out_dim=None, dropout=0.0): + def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0): super().__init__() out_dim = out_dim or in_dim self.in_dim = in_dim @@ -884,7 +989,7 @@ class TemporalConvLayer(nn.Module): nn.init.zeros_(self.conv4[-1].weight) nn.init.zeros_(self.conv4[-1].bias) - def forward(self, hidden_states, num_frames=1): + def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor: hidden_states = ( hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) )