mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
feat: rename single-letter vars in resnet.py (#3868)
feat: rename single-letter vars
This commit is contained in:
@@ -95,9 +95,9 @@ class Downsample1D(nn.Module):
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.conv(x)
|
||||
def forward(self, inputs):
|
||||
assert inputs.shape[1] == self.channels
|
||||
return self.conv(inputs)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
@@ -431,13 +431,13 @@ class KDownsample2D(nn.Module):
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
|
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(x.shape[1], device=x.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
|
||||
def forward(self, inputs):
|
||||
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv2d(x, weight, stride=2)
|
||||
return F.conv2d(inputs, weight, stride=2)
|
||||
|
||||
|
||||
class KUpsample2D(nn.Module):
|
||||
@@ -448,13 +448,13 @@ class KUpsample2D(nn.Module):
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(x.shape[1], device=x.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
|
||||
def forward(self, inputs):
|
||||
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
@@ -664,13 +664,13 @@ class Conv1dBlock(nn.Module):
|
||||
self.group_norm = nn.GroupNorm(n_groups, out_channels)
|
||||
self.mish = nn.Mish()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1d(x)
|
||||
x = rearrange_dims(x)
|
||||
x = self.group_norm(x)
|
||||
x = rearrange_dims(x)
|
||||
x = self.mish(x)
|
||||
return x
|
||||
def forward(self, inputs):
|
||||
intermediate_repr = self.conv1d(inputs)
|
||||
intermediate_repr = rearrange_dims(intermediate_repr)
|
||||
intermediate_repr = self.group_norm(intermediate_repr)
|
||||
intermediate_repr = rearrange_dims(intermediate_repr)
|
||||
output = self.mish(intermediate_repr)
|
||||
return output
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
@@ -687,10 +687,10 @@ class ResidualTemporalBlock1D(nn.Module):
|
||||
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, t):
|
||||
def forward(self, inputs, t):
|
||||
"""
|
||||
Args:
|
||||
x : [ batch_size x inp_channels x horizon ]
|
||||
inputs : [ batch_size x inp_channels x horizon ]
|
||||
t : [ batch_size x embed_dim ]
|
||||
|
||||
returns:
|
||||
@@ -698,9 +698,9 @@ class ResidualTemporalBlock1D(nn.Module):
|
||||
"""
|
||||
t = self.time_emb_act(t)
|
||||
t = self.time_emb(t)
|
||||
out = self.conv_in(x) + rearrange_dims(t)
|
||||
out = self.conv_in(inputs) + rearrange_dims(t)
|
||||
out = self.conv_out(out)
|
||||
return out + self.residual_conv(x)
|
||||
return out + self.residual_conv(inputs)
|
||||
|
||||
|
||||
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
||||
|
||||
Reference in New Issue
Block a user