
feat: rename single-letter vars in resnet.py (#3868)

feat: rename single-letter vars
Author: Saurav Maheshkar
Date: 2023-06-28 17:01:32 +05:30
Committed by: GitHub
parent 9a45d7fb76
commit 0bf6aeb885


@@ -95,9 +95,9 @@ class Downsample1D(nn.Module):
             assert self.channels == self.out_channels
             self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
 
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        return self.conv(x)
+    def forward(self, inputs):
+        assert inputs.shape[1] == self.channels
+        return self.conv(inputs)
 
 
 class Upsample2D(nn.Module):
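The `Downsample1D` change is purely mechanical: `forward` checks the channel count and applies the layer built in `__init__`. A minimal standalone sketch of the average-pooling path, with made-up shapes, assuming `use_conv=False`:

```python
import torch
import torch.nn as nn

# Equivalent of Downsample1D.forward on the AvgPool1d path:
# halve the length dimension with a stride-2 average pool.
pool = nn.AvgPool1d(kernel_size=2, stride=2)

inputs = torch.randn(4, 32, 64)  # [batch, channels, length]
assert inputs.shape[1] == 32     # the channel check from forward()
outputs = pool(inputs)
print(outputs.shape)             # torch.Size([4, 32, 32])
```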
@@ -431,13 +431,13 @@ class KDownsample2D(nn.Module):
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
 
-    def forward(self, x):
-        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
-        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
-        indices = torch.arange(x.shape[1], device=x.device)
-        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
+    def forward(self, inputs):
+        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
+        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
-        return F.conv2d(x, weight, stride=2)
+        return F.conv2d(inputs, weight, stride=2)
 
 
 class KUpsample2D(nn.Module):
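What the renamed `KDownsample2D.forward` does: it pads the input, builds a conv weight whose diagonal (channel `i` in, channel `i` out) holds the 2-D kernel, and applies a stride-2 convolution, so each channel is filtered independently. A self-contained sketch with a placeholder kernel and an assumed `pad_mode` (the real values come from `kernel_1d` in `__init__`, which this diff doesn't show):

```python
import torch
import torch.nn.functional as F

kernel = torch.full((4, 4), 1.0 / 16)  # hypothetical 4x4 smoothing kernel
pad = kernel.shape[1] // 2 - 1         # = 1, mirroring self.pad

inputs = torch.randn(2, 3, 8, 8)       # [batch, channels, height, width]
inputs = F.pad(inputs, (pad,) * 4, "reflect")  # assumed pad_mode
weight = inputs.new_zeros(inputs.shape[1], inputs.shape[1], *kernel.shape)
indices = torch.arange(inputs.shape[1])
weight[indices, indices] = kernel      # one kernel per channel, no cross-channel mixing
out = F.conv2d(inputs, weight, stride=2)
print(out.shape)                       # torch.Size([2, 3, 4, 4])
```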
@@ -448,13 +448,13 @@ class KUpsample2D(nn.Module):
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
 
-    def forward(self, x):
-        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
-        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
-        indices = torch.arange(x.shape[1], device=x.device)
-        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
+    def forward(self, inputs):
+        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
+        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
-        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
+        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
 
 
 class ResnetBlock2D(nn.Module):
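`KUpsample2D.forward` is the mirror image: the same diagonal weight goes through `F.conv_transpose2d`, which doubles the spatial size. Same caveats as above, placeholder kernel and assumed `pad_mode`:

```python
import torch
import torch.nn.functional as F

kernel = torch.full((4, 4), 1.0 / 4)   # hypothetical kernel
pad = kernel.shape[1] // 2 - 1         # = 1

inputs = torch.randn(2, 3, 4, 4)
inputs = F.pad(inputs, ((pad + 1) // 2,) * 4, "reflect")  # assumed pad_mode
weight = inputs.new_zeros(inputs.shape[1], inputs.shape[1], *kernel.shape)
indices = torch.arange(inputs.shape[1])
weight[indices, indices] = kernel
out = F.conv_transpose2d(inputs, weight, stride=2, padding=pad * 2 + 1)
print(out.shape)                       # torch.Size([2, 3, 8, 8]) -- spatial size doubled
```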
@@ -664,13 +664,13 @@ class Conv1dBlock(nn.Module):
         self.group_norm = nn.GroupNorm(n_groups, out_channels)
         self.mish = nn.Mish()
 
-    def forward(self, x):
-        x = self.conv1d(x)
-        x = rearrange_dims(x)
-        x = self.group_norm(x)
-        x = rearrange_dims(x)
-        x = self.mish(x)
-        return x
+    def forward(self, inputs):
+        intermediate_repr = self.conv1d(inputs)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        intermediate_repr = self.group_norm(intermediate_repr)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        output = self.mish(intermediate_repr)
+        return output
 
 
 # unet_rl.py
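The `Conv1dBlock` rewrite introduces two names, `intermediate_repr` for the conv/norm chain and `output` for the final activation, but the data flow is unchanged: conv1d -> group norm -> Mish. A standalone sketch with placeholder sizes (GroupNorm accepts 3-D input directly, so the `rearrange_dims` round-trip from the diff is omitted here):

```python
import torch
import torch.nn as nn

conv1d = nn.Conv1d(in_channels=14, out_channels=32, kernel_size=5, padding=2)
group_norm = nn.GroupNorm(num_groups=8, num_channels=32)
mish = nn.Mish()

inputs = torch.randn(4, 14, 16)       # [batch, inp_channels, horizon]
intermediate_repr = conv1d(inputs)
intermediate_repr = group_norm(intermediate_repr)
output = mish(intermediate_repr)
print(output.shape)                   # torch.Size([4, 32, 16])
```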
@@ -687,10 +687,10 @@ class ResidualTemporalBlock1D(nn.Module):
             nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
         )
 
-    def forward(self, x, t):
+    def forward(self, inputs, t):
         """
         Args:
-            x : [ batch_size x inp_channels x horizon ]
+            inputs : [ batch_size x inp_channels x horizon ]
             t : [ batch_size x embed_dim ]
 
         returns:
@@ -698,9 +698,9 @@ class ResidualTemporalBlock1D(nn.Module):
         """
         t = self.time_emb_act(t)
         t = self.time_emb(t)
-        out = self.conv_in(x) + rearrange_dims(t)
+        out = self.conv_in(inputs) + rearrange_dims(t)
         out = self.conv_out(out)
-        return out + self.residual_conv(x)
+        return out + self.residual_conv(inputs)
 
 
 def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
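For context on the last two hunks: `ResidualTemporalBlock1D` activates and projects the time embedding `t`, broadcasts it over the horizon, adds it to the convolved `inputs`, and closes with a residual shortcut. A sketch of that data flow with placeholder layer sizes (the block's actual submodules are defined in `__init__`, which this diff doesn't show):

```python
import torch
import torch.nn as nn

inp_channels, out_channels, embed_dim, horizon = 14, 32, 64, 16

conv_in = nn.Conv1d(inp_channels, out_channels, 5, padding=2)
conv_out = nn.Conv1d(out_channels, out_channels, 5, padding=2)
time_emb_act = nn.Mish()
time_emb = nn.Linear(embed_dim, out_channels)
residual_conv = nn.Conv1d(inp_channels, out_channels, 1)

inputs = torch.randn(4, inp_channels, horizon)  # [batch_size, inp_channels, horizon]
t = torch.randn(4, embed_dim)                   # [batch_size, embed_dim]

t = time_emb(time_emb_act(t))                   # activation, then projection, as in the diff
out = conv_in(inputs) + t[:, :, None]           # rearrange_dims(t): broadcast over horizon
out = conv_out(out)
out = out + residual_conv(inputs)
print(out.shape)                                # torch.Size([4, 32, 16])
```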