1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Chore] remove class assignments for linear and conv. (#7553)

* remove class assignments for linear and conv.

* fix: self.nn
This commit is contained in:
Sayak Paul
2024-04-02 13:01:04 +05:30
committed by GitHub
parent 5d83f50c23
commit 000fa82a1e
10 changed files with 38 additions and 61 deletions

View File

@@ -634,7 +634,6 @@ class FeedForward(nn.Module):
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = nn.Linear
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
@@ -651,7 +650,7 @@ class FeedForward(nn.Module):
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))

View File

@@ -181,25 +181,22 @@ class Attention(nn.Module):
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
linear_cls = nn.Linear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
if self.added_kv_proj_dim is not None:
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
@@ -706,7 +703,7 @@ class Attention(nn.Module):
out_features = concatenated_weights.shape[0]
# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
@@ -717,7 +714,7 @@ class Attention(nn.Module):
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])

View File

@@ -102,7 +102,6 @@ class Downsample2D(nn.Module):
self.padding = padding
stride = 2
self.name = name
conv_cls = nn.Conv2d
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ class Downsample2D(nn.Module):
raise ValueError(f"unknown norm_type: {norm_type}")
if use_conv:
conv = conv_cls(
conv = nn.Conv2d(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
)
else:

View File

@@ -199,9 +199,8 @@ class TimestepEmbedding(nn.Module):
sample_proj_bias=True,
):
super().__init__()
linear_cls = nn.Linear
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +213,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
if post_act_fn is None:
self.post_act = None

View File

@@ -101,8 +101,6 @@ class ResnetBlockCondNorm2D(nn.Module):
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
conv_cls = nn.Conv2d
if groups_out is None:
groups_out = groups
@@ -113,7 +111,7 @@ class ResnetBlockCondNorm2D(nn.Module):
else:
raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.time_embedding_norm == "ada_group": # ada_group
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
@@ -125,7 +123,7 @@ class ResnetBlockCondNorm2D(nn.Module):
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = get_activation(non_linearity)
@@ -139,7 +137,7 @@ class ResnetBlockCondNorm2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
self.conv_shortcut = nn.Conv2d(
in_channels,
conv_2d_out_channels,
kernel_size=1,
@@ -263,21 +261,18 @@ class ResnetBlock2D(nn.Module):
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
linear_cls = nn.Linear
conv_cls = nn.Conv2d
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = linear_cls(temb_channels, out_channels)
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
else:
@@ -287,7 +282,7 @@ class ResnetBlock2D(nn.Module):
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = get_activation(non_linearity)
@@ -313,7 +308,7 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
self.conv_shortcut = nn.Conv2d(
in_channels,
conv_2d_out_channels,
kernel_size=1,

View File

@@ -117,9 +117,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
conv_cls = nn.Conv2d
linear_cls = nn.Linear
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
@@ -159,9 +156,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = linear_cls(in_channels, inner_dim)
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -222,9 +219,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = linear_cls(inner_dim, in_channels)
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

View File

@@ -41,11 +41,11 @@ class SDCascadeLayerNorm(nn.LayerNorm):
class SDCascadeTimestepBlock(nn.Module):
def __init__(self, c, c_timestep, conds=[]):
super().__init__()
linear_cls = nn.Linear
self.mapper = linear_cls(c_timestep, c * 2)
self.mapper = nn.Linear(c_timestep, c * 2)
self.conds = conds
for cname in conds:
setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))
def forward(self, x, t):
t = t.chunk(len(self.conds) + 1, dim=1)
@@ -94,12 +94,11 @@ class GlobalResponseNorm(nn.Module):
class SDCascadeAttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
super().__init__()
linear_cls = nn.Linear
self.self_attn = self_attn
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
def forward(self, x, kv):
kv = self.kv_mapper(kv)

View File

@@ -110,7 +110,6 @@ class Upsample2D(nn.Module):
self.use_conv_transpose = use_conv_transpose
self.name = name
self.interpolate = interpolate
conv_cls = nn.Conv2d
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -131,7 +130,7 @@ class Upsample2D(nn.Module):
elif use_conv:
if kernel_size is None:
kernel_size = 3
conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":

View File

@@ -17,8 +17,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep):
super().__init__()
linear_cls = nn.Linear
self.mapper = linear_cls(c_timestep, c * 2)
self.mapper = nn.Linear(c_timestep, c * 2)
def forward(self, x, t):
a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -29,13 +29,10 @@ class ResBlock(nn.Module):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
super().__init__()
conv_cls = nn.Conv2d
linear_cls = nn.Linear
self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
)
def forward(self, x, x_skip=None):
@@ -64,12 +61,10 @@ class AttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
super().__init__()
linear_cls = nn.Linear
self.self_attn = self_attn
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
def forward(self, x, kv):
kv = self.kv_mapper(kv)

View File

@@ -40,15 +40,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
@register_to_config
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
super().__init__()
conv_cls = nn.Conv2d
linear_cls = nn.Linear
self.c_r = c_r
self.projection = conv_cls(c_in, c, kernel_size=1)
self.projection = nn.Conv2d(c_in, c, kernel_size=1)
self.cond_mapper = nn.Sequential(
linear_cls(c_cond, c),
nn.Linear(c_cond, c),
nn.LeakyReLU(0.2),
linear_cls(c, c),
nn.Linear(c, c),
)
self.blocks = nn.ModuleList()
@@ -58,7 +56,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
self.out = nn.Sequential(
WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
conv_cls(c, c_in * 2, kernel_size=1),
nn.Conv2d(c, c_in * 2, kernel_size=1),
)
self.gradient_checkpointing = False