mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
clean Linear
This commit is contained in:
@@ -157,6 +157,13 @@ def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale
|
||||
return conv
|
||||
|
||||
|
||||
def Linear(dim_in, dim_out):
|
||||
linear = nn.Linear(dim_in, dim_out)
|
||||
linear.weight.data = _variance_scaling()(linear.weight.shape)
|
||||
nn.init.zeros_(linear.bias)
|
||||
return linear
|
||||
|
||||
|
||||
class Combine(nn.Module):
|
||||
"""Combine information from skip connections."""
|
||||
|
||||
@@ -296,13 +303,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
raise ValueError(f"embedding type {embedding_type} unknown.")
|
||||
|
||||
if conditional:
|
||||
modules.append(nn.Linear(embed_dim, nf * 4))
|
||||
modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
|
||||
nn.init.zeros_(modules[-1].bias)
|
||||
modules.append(nn.Linear(nf * 4, nf * 4))
|
||||
modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
|
||||
nn.init.zeros_(modules[-1].bias)
|
||||
modules.append(Linear(embed_dim, nf * 4))
|
||||
modules.append(Linear(nf * 4, nf * 4))
|
||||
|
||||
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
|
||||
Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
|
||||
|
||||
Reference in New Issue
Block a user