1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Improve typehints and docs in diffusers/models (#5299)

* improvement: add typehints and docs to diffusers/models/activations.py

* improvement: add typehints and docs to diffusers/models/resnet.py
This commit is contained in:
Aryan V S
2023-10-09 19:59:23 +05:30
committed by GitHub
parent c4d66200b7
commit bd72927c63
2 changed files with 159 additions and 46 deletions

View File

@@ -1,7 +1,15 @@
from torch import nn
def get_activation(act_fn):
def get_activation(act_fn: str) -> nn.Module:
"""Helper function to get activation function from string.
Args:
act_fn (str): Name of activation function.
Returns:
nn.Module: Activation function.
"""
if act_fn in ["swish", "silu"]:
return nn.SiLU()
elif act_fn == "mish":

View File

@@ -14,7 +14,7 @@
# limitations under the License.
from functools import partial
from typing import Optional
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -38,9 +38,18 @@ class Upsample1D(nn.Module):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 1D layer.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -54,7 +63,7 @@ class Upsample1D(nn.Module):
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, inputs):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(inputs)
@@ -79,9 +88,18 @@ class Downsample1D(nn.Module):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 1D layer.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -96,7 +114,7 @@ class Downsample1D(nn.Module):
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, inputs):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
return self.conv(inputs)
@@ -113,9 +131,18 @@ class Upsample2D(nn.Module):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 2D layer.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -135,7 +162,7 @@ class Upsample2D(nn.Module):
else:
self.Conv2d_0 = conv
def forward(self, hidden_states, output_size=None, scale: float = 1.0):
def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
@@ -191,9 +218,18 @@ class Downsample2D(nn.Module):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 2D layer.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -246,7 +282,13 @@ class FirUpsample2D(nn.Module):
kernel for the FIR filter.
"""
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
def __init__(
self,
channels: int = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
@@ -255,7 +297,14 @@ class FirUpsample2D(nn.Module):
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
def _upsample_2d(
self,
hidden_states: torch.Tensor,
weight: Optional[torch.Tensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.Tensor:
"""Fused `upsample_2d()` followed by `Conv2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -335,7 +384,7 @@ class FirUpsample2D(nn.Module):
return output
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.use_conv:
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -359,7 +408,13 @@ class FirDownsample2D(nn.Module):
kernel for the FIR filter.
"""
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
def __init__(
self,
channels: int = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
@@ -368,7 +423,14 @@ class FirDownsample2D(nn.Module):
self.use_conv = use_conv
self.out_channels = out_channels
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
def _downsample_2d(
self,
hidden_states: torch.Tensor,
weight: Optional[torch.Tensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.Tensor:
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
@@ -422,7 +484,7 @@ class FirDownsample2D(nn.Module):
return output
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.use_conv:
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -434,14 +496,20 @@ class FirDownsample2D(nn.Module):
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
class KDownsample2D(nn.Module):
def __init__(self, pad_mode="reflect"):
r"""A 2D K-downsampling layer.
Parameters:
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
"""
def __init__(self, pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, inputs):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(inputs.shape[1], device=inputs.device)
@@ -451,14 +519,20 @@ class KDownsample2D(nn.Module):
class KUpsample2D(nn.Module):
def __init__(self, pad_mode="reflect"):
r"""A 2D K-upsampling layer.
Parameters:
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
"""
def __init__(self, pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, inputs):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(inputs.shape[1], device=inputs.device)
@@ -501,23 +575,23 @@ class ResnetBlock2D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
skip_time_act=False,
time_embedding_norm="default", # default, scale_shift, ada_group, spatial
kernel=None,
output_scale_factor=1.0,
use_in_shortcut=None,
up=False,
down=False,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
groups_out: Optional[int] = None,
pre_norm: bool = True,
eps: float = 1e-6,
non_linearity: str = "swish",
skip_time_act: bool = False,
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
kernel: Optional[torch.FloatTensor] = None,
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
up: bool = False,
down: bool = False,
conv_shortcut_bias: bool = True,
conv_2d_out_channels: Optional[int] = None,
):
@@ -667,7 +741,7 @@ class ResnetBlock2D(nn.Module):
# unet_rl.py
def rearrange_dims(tensor):
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
@@ -681,16 +755,24 @@ def rearrange_dims(tensor):
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
Parameters:
inp_channels (`int`): Number of input channels.
out_channels (`int`): Number of output channels.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
n_groups (`int`, default `8`): Number of groups to separate the channels into.
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
def __init__(
self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
):
super().__init__()
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.group_norm = nn.GroupNorm(n_groups, out_channels)
self.mish = nn.Mish()
def forward(self, inputs):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
intermediate_repr = self.conv1d(inputs)
intermediate_repr = rearrange_dims(intermediate_repr)
intermediate_repr = self.group_norm(intermediate_repr)
@@ -701,7 +783,19 @@ class Conv1dBlock(nn.Module):
# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
"""
Residual 1D block with temporal convolutions.
Parameters:
inp_channels (`int`): Number of input channels.
out_channels (`int`): Number of output channels.
embed_dim (`int`): Embedding dimension.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
"""
def __init__(
self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
):
super().__init__()
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
@@ -713,7 +807,7 @@ class ResidualTemporalBlock1D(nn.Module):
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, inputs, t):
def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
Args:
inputs : [ batch_size x inp_channels x horizon ]
@@ -729,7 +823,9 @@ class ResidualTemporalBlock1D(nn.Module):
return out + self.residual_conv(inputs)
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
def upsample_2d(
hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
@@ -766,7 +862,9 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
return output
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
def downsample_2d(
hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -801,7 +899,9 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
return output
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
def upfirdn2d_native(
tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
@@ -849,9 +949,14 @@ class TemporalConvLayer(nn.Module):
"""
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
Parameters:
in_dim (`int`): Number of input channels.
out_dim (`int`): Number of output channels.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
"""
def __init__(self, in_dim, out_dim=None, dropout=0.0):
def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
@@ -884,7 +989,7 @@ class TemporalConvLayer(nn.Module):
nn.init.zeros_(self.conv4[-1].weight)
nn.init.zeros_(self.conv4[-1].bias)
def forward(self, hidden_states, num_frames=1):
def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
hidden_states = (
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
)