Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Improve typehints and docs in diffusers/models (#5299)
* improvement: add typehints and docs to diffusers/models/activations.py
* improvement: add typehints and docs to diffusers/models/resnet.py
diffusers/models/activations.py:

@@ -1,7 +1,15 @@
 from torch import nn


-def get_activation(act_fn):
+def get_activation(act_fn: str) -> nn.Module:
+    """Helper function to get activation function from string.
+
+    Args:
+        act_fn (str): Name of activation function.
+
+    Returns:
+        nn.Module: Activation function.
+    """
     if act_fn in ["swish", "silu"]:
         return nn.SiLU()
     elif act_fn == "mish":
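For orientation, a minimal usage sketch of the newly annotated helper (the import path follows the file named in the commit message):

```python
import torch
from torch import nn

from diffusers.models.activations import get_activation

act = get_activation("swish")  # returns nn.SiLU(), per the branch shown above
assert isinstance(act, nn.Module)
y = act(torch.randn(2, 4))  # use like any nn.Module
```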
diffusers/models/resnet.py:

@@ -14,7 +14,7 @@
 # limitations under the License.

 from functools import partial
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -38,9 +38,18 @@ class Upsample1D(nn.Module):
             option to use a convolution transpose.
         out_channels (`int`, optional):
             number of output channels. Defaults to `channels`.
+        name (`str`, default `conv`):
+            name of the upsampling 1D layer.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        use_conv_transpose: bool = False,
+        out_channels: Optional[int] = None,
+        name: str = "conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -54,7 +63,7 @@ class Upsample1D(nn.Module):
         elif use_conv:
             self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         assert inputs.shape[1] == self.channels
         if self.use_conv_transpose:
             return self.conv(inputs)
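A minimal usage sketch of the annotated `Upsample1D` (channel count and tensor shape are illustrative):

```python
import torch

from diffusers.models.resnet import Upsample1D

up = Upsample1D(channels=32, use_conv=True)
x = torch.randn(1, 32, 64)  # (batch, channels, length)
y = up(x)  # interpolation doubles the length: (1, 32, 128)
```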
@@ -79,9 +88,18 @@ class Downsample1D(nn.Module):
             number of output channels. Defaults to `channels`.
         padding (`int`, default `1`):
             padding for the convolution.
+        name (`str`, default `conv`):
+            name of the downsampling 1D layer.
     """

-    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        out_channels: Optional[int] = None,
+        padding: int = 1,
+        name: str = "conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -96,7 +114,7 @@ class Downsample1D(nn.Module):
             assert self.channels == self.out_channels
             self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         assert inputs.shape[1] == self.channels
         return self.conv(inputs)

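And the matching downsampling sketch (again with illustrative shapes):

```python
import torch

from diffusers.models.resnet import Downsample1D

down = Downsample1D(channels=32, use_conv=True, padding=1)
x = torch.randn(1, 32, 64)
y = down(x)  # stride-2 convolution halves the length: (1, 32, 32)
```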
@@ -113,9 +131,18 @@ class Upsample2D(nn.Module):
             option to use a convolution transpose.
         out_channels (`int`, optional):
             number of output channels. Defaults to `channels`.
+        name (`str`, default `conv`):
+            name of the upsampling 2D layer.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        use_conv_transpose: bool = False,
+        out_channels: Optional[int] = None,
+        name: str = "conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -135,7 +162,7 @@ class Upsample2D(nn.Module):
         else:
             self.Conv2d_0 = conv

-    def forward(self, hidden_states, output_size=None, scale: float = 1.0):
+    def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
         assert hidden_states.shape[1] == self.channels

         if self.use_conv_transpose:
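A usage sketch of the annotated `Upsample2D` (shapes illustrative):

```python
import torch

from diffusers.models.resnet import Upsample2D

up = Upsample2D(channels=64, use_conv=True)
x = torch.randn(1, 64, 16, 16)
y = up(x)  # interpolate then conv: (1, 64, 32, 32)
```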
@@ -191,9 +218,18 @@ class Downsample2D(nn.Module):
             number of output channels. Defaults to `channels`.
         padding (`int`, default `1`):
             padding for the convolution.
+        name (`str`, default `conv`):
+            name of the downsampling 2D layer.
     """

-    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        out_channels: Optional[int] = None,
+        padding: int = 1,
+        name: str = "conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
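The 2D downsampling counterpart, sketched with illustrative shapes:

```python
import torch

from diffusers.models.resnet import Downsample2D

down = Downsample2D(channels=64, use_conv=True, padding=1)
x = torch.randn(1, 64, 32, 32)
y = down(x)  # stride-2 conv: (1, 64, 16, 16)
```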
@@ -246,7 +282,13 @@ class FirUpsample2D(nn.Module):
             kernel for the FIR filter.
     """

-    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(
+        self,
+        channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
+        use_conv: bool = False,
+        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
+    ):
         super().__init__()
         out_channels = out_channels if out_channels else channels
         if use_conv:
@@ -255,7 +297,14 @@ class FirUpsample2D(nn.Module):
         self.fir_kernel = fir_kernel
         self.out_channels = out_channels

-    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
+    def _upsample_2d(
+        self,
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        kernel: Optional[torch.FloatTensor] = None,
+        factor: int = 2,
+        gain: float = 1,
+    ) -> torch.Tensor:
         """Fused `upsample_2d()` followed by `Conv2d()`.

         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -335,7 +384,7 @@ class FirUpsample2D(nn.Module):

         return output

-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.use_conv:
             height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
             height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -359,7 +408,13 @@ class FirDownsample2D(nn.Module):
             kernel for the FIR filter.
     """

-    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(
+        self,
+        channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
+        use_conv: bool = False,
+        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
+    ):
         super().__init__()
         out_channels = out_channels if out_channels else channels
         if use_conv:
@@ -368,7 +423,14 @@ class FirDownsample2D(nn.Module):
         self.use_conv = use_conv
         self.out_channels = out_channels

-    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
+    def _downsample_2d(
+        self,
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        kernel: Optional[torch.FloatTensor] = None,
+        factor: int = 2,
+        gain: float = 1,
+    ) -> torch.Tensor:
         """Fused `Conv2d()` followed by `downsample_2d()`.
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
         efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
@@ -422,7 +484,7 @@ class FirDownsample2D(nn.Module):

         return output

-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.use_conv:
             downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
             hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
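Both FIR layers share the same call pattern; a hedged sketch using the default `(1, 3, 3, 1)` kernel (shapes illustrative):

```python
import torch

from diffusers.models.resnet import FirDownsample2D, FirUpsample2D

fir_up = FirUpsample2D(channels=64, use_conv=True)
fir_down = FirDownsample2D(channels=64, use_conv=True)

x = torch.randn(1, 64, 16, 16)
y = fir_up(x)  # FIR-filtered 2x upsample: (1, 64, 32, 32)
z = fir_down(y)  # FIR-filtered 2x downsample: (1, 64, 16, 16)
```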
@@ -434,14 +496,20 @@ class FirDownsample2D(nn.Module):

 # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
 class KDownsample2D(nn.Module):
-    def __init__(self, pad_mode="reflect"):
+    r"""A 2D K-downsampling layer.
+
+    Parameters:
+        pad_mode (`str`, *optional*, defaults to `"reflect"`): the padding mode to use.
+    """
+
+    def __init__(self, pad_mode: str = "reflect"):
         super().__init__()
         self.pad_mode = pad_mode
         kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
         weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
         indices = torch.arange(inputs.shape[1], device=inputs.device)
@@ -451,14 +519,20 @@ class KDownsample2D(nn.Module):


 class KUpsample2D(nn.Module):
-    def __init__(self, pad_mode="reflect"):
+    r"""A 2D K-upsampling layer.
+
+    Parameters:
+        pad_mode (`str`, *optional*, defaults to `"reflect"`): the padding mode to use.
+    """
+
+    def __init__(self, pad_mode: str = "reflect"):
         super().__init__()
         self.pad_mode = pad_mode
         kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
         weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
         indices = torch.arange(inputs.shape[1], device=inputs.device)
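The K-layers take no size arguments at all, only a padding mode; a sketch (shapes illustrative):

```python
import torch

from diffusers.models.resnet import KDownsample2D, KUpsample2D

x = torch.randn(1, 3, 32, 32)
y = KDownsample2D()(x)  # fixed binomial filter, stride 2: (1, 3, 16, 16)
z = KUpsample2D()(y)  # transposed counterpart: (1, 3, 32, 32)
```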
@@ -501,23 +575,23 @@ class ResnetBlock2D(nn.Module):
     def __init__(
         self,
         *,
-        in_channels,
-        out_channels=None,
-        conv_shortcut=False,
-        dropout=0.0,
-        temb_channels=512,
-        groups=32,
-        groups_out=None,
-        pre_norm=True,
-        eps=1e-6,
-        non_linearity="swish",
-        skip_time_act=False,
-        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
-        kernel=None,
-        output_scale_factor=1.0,
-        use_in_shortcut=None,
-        up=False,
-        down=False,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        conv_shortcut: bool = False,
+        dropout: float = 0.0,
+        temb_channels: int = 512,
+        groups: int = 32,
+        groups_out: Optional[int] = None,
+        pre_norm: bool = True,
+        eps: float = 1e-6,
+        non_linearity: str = "swish",
+        skip_time_act: bool = False,
+        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
+        kernel: Optional[torch.FloatTensor] = None,
+        output_scale_factor: float = 1.0,
+        use_in_shortcut: Optional[bool] = None,
+        up: bool = False,
+        down: bool = False,
         conv_shortcut_bias: bool = True,
         conv_2d_out_channels: Optional[int] = None,
     ):
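Every `ResnetBlock2D` argument is keyword-only (note the `*`); a hedged construction sketch, assuming the usual `(hidden_states, temb)` forward signature, which this hunk does not show (values illustrative):

```python
import torch

from diffusers.models.resnet import ResnetBlock2D

block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
x = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 512)  # time embedding
y = block(x, temb)  # (1, 128, 32, 32)
```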
@@ -667,7 +741,7 @@ class ResnetBlock2D(nn.Module):


 # unet_rl.py
-def rearrange_dims(tensor):
+def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
     if len(tensor.shape) == 2:
         return tensor[:, :, None]
     if len(tensor.shape) == 3:
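For orientation, the 2D branch shown above behaves as follows (the 3D/4D branches are truncated by this hunk):

```python
import torch

from diffusers.models.resnet import rearrange_dims

t = torch.randn(4, 8)
rearrange_dims(t).shape  # torch.Size([4, 8, 1]): a trailing axis is added
```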
@@ -681,16 +755,24 @@ def rearrange_dims(tensor):
 class Conv1dBlock(nn.Module):
     """
     Conv1d --> GroupNorm --> Mish
+
+    Parameters:
+        inp_channels (`int`): Number of input channels.
+        out_channels (`int`): Number of output channels.
+        kernel_size (`int` or `tuple`): Size of the convolving kernel.
+        n_groups (`int`, default `8`): Number of groups to separate the channels into.
     """

-    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+    def __init__(
+        self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
+    ):
         super().__init__()

         self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
         self.group_norm = nn.GroupNorm(n_groups, out_channels)
         self.mish = nn.Mish()

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         intermediate_repr = self.conv1d(inputs)
         intermediate_repr = rearrange_dims(intermediate_repr)
         intermediate_repr = self.group_norm(intermediate_repr)
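A usage sketch of the annotated block (shapes illustrative; the `padding=kernel_size // 2` shown above keeps the length for odd kernels):

```python
import torch

from diffusers.models.resnet import Conv1dBlock

block = Conv1dBlock(inp_channels=16, out_channels=32, kernel_size=5)
x = torch.randn(1, 16, 64)
y = block(x)  # Conv1d -> GroupNorm -> Mish: (1, 32, 64)
```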
@@ -701,7 +783,19 @@ class Conv1dBlock(nn.Module):

 # unet_rl.py
 class ResidualTemporalBlock1D(nn.Module):
-    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
+    """
+    Residual 1D block with temporal convolutions.
+
+    Parameters:
+        inp_channels (`int`): Number of input channels.
+        out_channels (`int`): Number of output channels.
+        embed_dim (`int`): Embedding dimension.
+        kernel_size (`int` or `tuple`): Size of the convolving kernel.
+    """
+
+    def __init__(
+        self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
+    ):
         super().__init__()
         self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
         self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
@@ -713,7 +807,7 @@ class ResidualTemporalBlock1D(nn.Module):
             nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
         )

-    def forward(self, inputs, t):
+    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         """
         Args:
             inputs : [ batch_size x inp_channels x horizon ]
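A construction sketch matching the documented shapes (values illustrative):

```python
import torch

from diffusers.models.resnet import ResidualTemporalBlock1D

block = ResidualTemporalBlock1D(inp_channels=16, out_channels=32, embed_dim=128)
x = torch.randn(2, 16, 64)  # [ batch_size x inp_channels x horizon ]
t = torch.randn(2, 128)  # time embedding of size embed_dim
out = block(x, t)  # (2, 32, 64)
```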
@@ -729,7 +823,9 @@ class ResidualTemporalBlock1D(nn.Module):
         return out + self.residual_conv(inputs)


-def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+def upsample_2d(
+    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+) -> torch.Tensor:
     r"""Upsamples a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
     filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
@@ -766,7 +862,9 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     return output


-def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+def downsample_2d(
+    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+) -> torch.Tensor:
     r"""Downsamples a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
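The two functional helpers are inverses in shape; a round-trip sketch with the default kernel:

```python
import torch

from diffusers.models.resnet import downsample_2d, upsample_2d

x = torch.randn(1, 3, 16, 16)
up = upsample_2d(x, factor=2)  # (1, 3, 32, 32)
down = downsample_2d(up, factor=2)  # back to (1, 3, 16, 16)
```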
@@ -801,7 +899,9 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     return output


-def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
+def upfirdn2d_native(
+    tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
+) -> torch.Tensor:
     up_x = up_y = up
     down_x = down_y = down
     pad_x0 = pad_y0 = pad[0]
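`upfirdn2d_native` (upsample, FIR filter, downsample) is the primitive behind both helpers above; a hedged sketch with a normalized 2x2 box kernel (kernel choice and padding are illustrative):

```python
import torch

from diffusers.models.resnet import upfirdn2d_native

x = torch.randn(1, 3, 8, 8)  # [N, C, H, W]
kernel = torch.ones(2, 2) / 4  # normalized box filter
y = upfirdn2d_native(x, kernel, up=2, down=1, pad=(0, 1))  # roughly 2x upsampled: (1, 3, 16, 16)
```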
@@ -849,9 +949,14 @@ class TemporalConvLayer(nn.Module):
     """
     Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
     https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
+
+    Parameters:
+        in_dim (`int`): Number of input channels.
+        out_dim (`int`): Number of output channels.
+        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
     """

-    def __init__(self, in_dim, out_dim=None, dropout=0.0):
+    def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0):
         super().__init__()
         out_dim = out_dim or in_dim
         self.in_dim = in_dim
@@ -884,7 +989,7 @@ class TemporalConvLayer(nn.Module):
         nn.init.zeros_(self.conv4[-1].weight)
         nn.init.zeros_(self.conv4[-1].bias)

-    def forward(self, hidden_states, num_frames=1):
+    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
         hidden_states = (
             hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
         )
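As the reshape above implies, the layer expects video frames flattened into the batch dimension; a sketch (shapes illustrative):

```python
import torch

from diffusers.models.resnet import TemporalConvLayer

layer = TemporalConvLayer(in_dim=32)
num_frames = 8
x = torch.randn(2 * num_frames, 32, 16, 16)  # (batch * frames, channels, h, w)
y = layer(x, num_frames=num_frames)  # same shape; the zero-initialized conv4 shown above makes the layer start near an identity mapping
```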