From 1d213def63ffc3ff708b534768e48d7db2a8f10c Mon Sep 17 00:00:00 2001 From: Will Rice Date: Fri, 5 May 2023 14:50:41 -0400 Subject: [PATCH] Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275) The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument. --- src/diffusers/models/unet_2d_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 57153fa398..2f7b19b732 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -734,7 +734,7 @@ class AttnDownBlock2D(nn.Module): else: self.downsamplers = None - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, upsample_size=None): output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -1720,7 +1720,7 @@ class AttnUpBlock2D(nn.Module): else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1]