mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[VAE] fix the downsample block in Encoder. (#156)
* pass downsample_padding in encoder * update tests
This commit is contained in:
@@ -40,6 +40,7 @@ class Encoder(nn.Module):
|
||||
out_channels=output_channel,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=1e-6,
|
||||
downsample_padding=0,
|
||||
resnet_act_fn=act_fn,
|
||||
attn_num_head_channels=None,
|
||||
temb_channels=None,
|
||||
|
||||
@@ -555,11 +555,11 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": [64],
|
||||
"block_out_channels": [32, 64],
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D"],
|
||||
"up_block_types": ["UpDecoderBlock2D"],
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
"latent_channels": 3,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
@@ -595,7 +595,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
|
||||
expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
@@ -623,22 +623,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"ch": 64,
|
||||
"ch_mult": (1,),
|
||||
"embed_dim": 4,
|
||||
"in_channels": 3,
|
||||
"attn_resolutions": [],
|
||||
"num_res_blocks": 1,
|
||||
"out_ch": 3,
|
||||
"resolution": 32,
|
||||
"z_channels": 4,
|
||||
}
|
||||
init_dict = {
|
||||
"block_out_channels": [64],
|
||||
"block_out_channels": [32, 64],
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D"],
|
||||
"up_block_types": ["UpDecoderBlock2D"],
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
"latent_channels": 4,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
@@ -674,7 +663,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-0.3900, -0.2800, 0.1281, -0.4449, -0.4890, -0.0207, 0.0784, -0.1258, -0.0409])
|
||||
expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user