1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

revert change to make less breaking

This commit is contained in:
Nathan Lambert
2022-11-14 14:12:47 -08:00
parent 8b7f2e301d
commit 64c5688284
2 changed files with 5 additions and 4 deletions

View File

@@ -57,7 +57,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(32, 32, 64)`): Tuple of block output channels.
up_down_block_layers (`int`, defaults to 2):
layers_per_block (`int`, defaults to 2):
number of resnet, attention, or other layers in the up and down blocks.
mid_block_layers (`int`, defaults to 5): number of resnet, attention, or other layers in the mid block.
mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
@@ -85,7 +85,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
mid_block_type: Tuple[str] = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
up_down_block_layers: int = 2,
layers_per_block: int = 2,
mid_block_layers: int = 5,
act_fn: str = None,
norm_num_groups: int = 8,
@@ -135,7 +135,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
down_block_type,
in_channels=input_channel,
out_channels=output_channel,
num_layers=up_down_block_layers,
num_layers=layers_per_block,
temb_channels=block_out_channels[0],
add_downsample=not is_final_block or downsample_each_block,
)
@@ -172,7 +172,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
up_block_type,
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=up_down_block_layers,
num_layers=layers_per_block,
temb_channels=block_out_channels[0],
add_upsample=not is_final_block,
)

View File

@@ -224,6 +224,7 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
"mid_block_type": "ValueFunctionMidBlock1D",
"block_out_channels": [32, 64, 128, 256],
"layers_per_block": 1,
"mid_block_layers": 1,
"downsample_each_block": True,
"use_timestep_embedding": True,
"freq_shift": 1.0,