1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275)

The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument.
This commit is contained in:
Will Rice
2023-05-05 14:50:41 -04:00
committed by Daniel Gu
parent 652dbaad8d
commit 1d213def63

View File

@@ -734,7 +734,7 @@ class AttnDownBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, upsample_size=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
@@ -1720,7 +1720,7 @@ class AttnUpBlock2D(nn.Module):
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]