1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

clean Linear

This commit is contained in:
patil-suraj
2022-06-30 13:31:47 +02:00
parent 3e2cff4da2
commit f35387b33f

View File

@@ -157,6 +157,13 @@ def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale
return conv
def Linear(dim_in, dim_out):
linear = nn.Linear(dim_in, dim_out)
linear.weight.data = _variance_scaling()(linear.weight.shape)
nn.init.zeros_(linear.bias)
return linear
class Combine(nn.Module):
"""Combine information from skip connections."""
@@ -296,13 +303,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
else:
raise ValueError(f"embedding type {embedding_type} unknown.")
if conditional:
modules.append(nn.Linear(embed_dim, nf * 4))
modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
nn.init.zeros_(modules[-1].bias)
modules.append(nn.Linear(nf * 4, nf * 4))
modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
nn.init.zeros_(modules[-1].bias)
modules.append(Linear(embed_dim, nf * 4))
modules.append(Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)