Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Chore] remove class assignments for linear and conv. (#7553)
* remove class assignments for linear and conv.
* fix: self.nn
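The change is mechanical: each constructor previously bound nn.Linear / nn.Conv2d to a local alias (linear_cls / conv_cls) and instantiated layers through that alias. The aliases look like leftovers of the earlier LoRA layer-switching machinery; after the PEFT refactor they always resolved to the plain torch classes, so the indirection can be dropped. A minimal sketch of the before/after pattern, using hypothetical ToyBlock modules for illustration only:

import torch.nn as nn

# Before: the layer class is bound to a local alias and called indirectly.
class ToyBlockBefore(nn.Module):
    def __init__(self, dim):
        super().__init__()
        linear_cls = nn.Linear  # class assignment of the kind this PR removes
        self.proj = linear_cls(dim, dim)

# After: the layer is constructed directly; the resulting module is identical.
class ToyBlockAfter(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

Since linear_cls was always nn.Linear, both constructors build the same module graph and existing checkpoints stay compatible; only the indirection goes away.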
@@ -634,7 +634,6 @@ class FeedForward(nn.Module):
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = nn.Linear
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
@@ -651,7 +650,7 @@ class FeedForward(nn.Module):
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
@@ -181,25 +181,22 @@ class Attention(nn.Module):
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )
 
-        linear_cls = nn.Linear
-
-        self.linear_cls = linear_cls
-        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
 
         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None
 
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
@@ -706,7 +703,7 @@ class Attention(nn.Module):
             out_features = concatenated_weights.shape[0]
 
             # create a new single projection layer and copy over the weights.
-            self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_qkv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
@@ -717,7 +714,7 @@ class Attention(nn.Module):
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_kv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
@@ -102,7 +102,6 @@ class Downsample2D(nn.Module):
         self.padding = padding
         stride = 2
         self.name = name
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ class Downsample2D(nn.Module):
             raise ValueError(f"unknown norm_type: {norm_type}")
 
         if use_conv:
-            conv = conv_cls(
+            conv = nn.Conv2d(
                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
             )
         else:
@@ -199,9 +199,8 @@ class TimestepEmbedding(nn.Module):
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear
 
-        self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
 
         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +213,7 @@ class TimestepEmbedding(nn.Module):
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
 
         if post_act_fn is None:
             self.post_act = None
@@ -101,8 +101,6 @@ class ResnetBlockCondNorm2D(nn.Module):
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
 
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
@@ -113,7 +111,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         else:
             raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if self.time_embedding_norm == "ada_group":  # ada_group
             self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
@@ -125,7 +123,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         self.dropout = torch.nn.Dropout(dropout)
 
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -139,7 +137,7 @@ class ResnetBlockCondNorm2D(nn.Module):
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -263,21 +261,18 @@ class ResnetBlock2D(nn.Module):
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act
 
-        linear_cls = nn.Linear
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
         self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if temb_channels is not None:
             if self.time_embedding_norm == "default":
-                self.time_emb_proj = linear_cls(temb_channels, out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
             else:
                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
         else:
@@ -287,7 +282,7 @@ class ResnetBlock2D(nn.Module):
 
         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -313,7 +308,7 @@ class ResnetBlock2D(nn.Module):
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -117,9 +117,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
 
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
@@ -159,9 +156,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
             self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
             if use_linear_projection:
-                self.proj_in = linear_cls(in_channels, inner_dim)
+                self.proj_in = nn.Linear(in_channels, inner_dim)
             else:
-                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
             assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -222,9 +219,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         if self.is_input_continuous:
             # TODO: should use out_channels for continuous projections
             if use_linear_projection:
-                self.proj_out = linear_cls(inner_dim, in_channels)
+                self.proj_out = nn.Linear(inner_dim, in_channels)
             else:
-                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             self.norm_out = nn.LayerNorm(inner_dim)
             self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
@@ -41,11 +41,11 @@ class SDCascadeLayerNorm(nn.LayerNorm):
 class SDCascadeTimestepBlock(nn.Module):
     def __init__(self, c, c_timestep, conds=[]):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)
         self.conds = conds
         for cname in conds:
-            setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
+            setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))
 
     def forward(self, x, t):
         t = t.chunk(len(self.conds) + 1, dim=1)
@@ -94,12 +94,11 @@ class GlobalResponseNorm(nn.Module):
 class SDCascadeAttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
-        linear_cls = nn.Linear
 
         self.self_attn = self_attn
         self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
@@ -110,7 +110,6 @@ class Upsample2D(nn.Module):
         self.use_conv_transpose = use_conv_transpose
         self.name = name
         self.interpolate = interpolate
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -131,7 +130,7 @@ class Upsample2D(nn.Module):
         elif use_conv:
             if kernel_size is None:
                 kernel_size = 3
-            conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
@@ -17,8 +17,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
 class TimestepBlock(nn.Module):
     def __init__(self, c, c_timestep):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)
 
     def forward(self, x, t):
         a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -29,13 +29,10 @@ class ResBlock(nn.Module):
     def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         super().__init__()
 
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
-        self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
+            nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
         )
 
     def forward(self, x, x_skip=None):
@@ -64,12 +61,10 @@ class AttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
 
-        linear_cls = nn.Linear
-
         self.self_attn = self_attn
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
@@ -40,15 +40,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
 
         self.c_r = c_r
-        self.projection = conv_cls(c_in, c, kernel_size=1)
+        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
         self.cond_mapper = nn.Sequential(
-            linear_cls(c_cond, c),
+            nn.Linear(c_cond, c),
             nn.LeakyReLU(0.2),
-            linear_cls(c, c),
+            nn.Linear(c, c),
         )
 
         self.blocks = nn.ModuleList()
@@ -58,7 +56,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
             self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
         self.out = nn.Sequential(
             WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
-            conv_cls(c, c_in * 2, kernel_size=1),
+            nn.Conv2d(c, c_in * 2, kernel_size=1),
         )
 
         self.gradient_checkpointing = False