mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
new test case
This commit is contained in:
@@ -46,24 +46,34 @@ class AdapterTests:
|
||||
|
||||
def get_dummy_components(self, adapter_type):
|
||||
torch.manual_seed(0)
|
||||
if adapter_type == 'light_adapter':
|
||||
channels = [32, 32, 32]
|
||||
else:
|
||||
channels = [32, 32, 32, 32]
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
block_out_channels=[32, 32, 32, 32],
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
down_block_types=(
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
up_block_types= ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
block_out_channels=[32, 32, 32, 32],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
@@ -84,7 +94,7 @@ class AdapterTests:
|
||||
torch.manual_seed(0)
|
||||
adapter = T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[320, 640, 1280, 1280],
|
||||
channels=channels,
|
||||
num_res_blocks=2,
|
||||
downscale_factor=8,
|
||||
adapter_type=adapter_type,
|
||||
|
||||
Reference in New Issue
Block a user