mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275)
The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument.
This commit is contained in:
@@ -734,7 +734,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
def forward(self, hidden_states, temb=None, upsample_size=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
@@ -1720,7 +1720,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
|
||||
Reference in New Issue
Block a user