
[Vae] Make sure all VAEs work with latent diffusion models (#5880)

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more
Commit e550163b9f (parent 20f0cbc88f)
Author: Patrick von Platen
Committed by GitHub on 2023-11-27 14:17:47 +01:00
21 changed files with 277 additions and 112 deletions
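
Summary of the change: most image-to-image, inpaint, and video pipelines used to call `self.vae.encode(...).latent_dist.sample(...)` (or `.mode()`), which assumes a KL-style VAE. `AutoencoderTiny` returns an object that exposes `.latents` and has no `latent_dist`, so those pipelines broke when a different VAE was swapped in. This commit routes all of those call sites through a single `retrieve_latents` helper and registers `block_out_channels` and `force_upcast` on the alternative VAE classes so pipelines can read both from `vae.config`. Below is a minimal, self-contained sketch of the helper's dispatch; the helper body is copied from the diffs that follow, and the toy classes only stand in for the real encoder outputs.

import torch
from types import SimpleNamespace


# Helper as added to the pipeline files below (type annotations omitted).
def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class ToyDist:
    # Stand-in for DiagonalGaussianDistribution, just enough for the dispatch above.
    def __init__(self, mean):
        self.mean = mean

    def sample(self, generator=None):
        return self.mean + torch.randn(self.mean.shape, generator=generator)

    def mode(self):
        return self.mean


mean = torch.zeros(1, 4, 8, 8)
kl_style = SimpleNamespace(latent_dist=ToyDist(mean))  # shape of AutoencoderKL.encode output
tiny_style = SimpleNamespace(latents=mean)             # shape of AutoencoderTiny.encode output

retrieve_latents(kl_style, generator=torch.Generator().manual_seed(0))  # stochastic sample
retrieve_latents(kl_style, sample_mode="argmax")                        # deterministic mode
retrieve_latents(tiny_style)                                            # passthrough of .latents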

View File

@@ -108,6 +108,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         self.use_slicing = False
         self.use_tiling = False
 
+        self.register_to_config(block_out_channels=up_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
     @apply_forward_hook
     def encode(
         self, x: torch.FloatTensor, return_dict: bool = True
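
`force_upcast` is read by the fp16 pipelines to decide whether the VAE must be run in float32 for numerical stability; the stock SD/SDXL KL VAE needs that, while the asymmetric, tiny, and consistency-decoder VAEs do not, so they now register `force_upcast=False`. Registering `block_out_channels` likewise lets pipelines compute the VAE scale factor for any of these classes. A rough sketch of the pipeline-side pattern that consumes the flag (illustrative names; the real call sites appear in the inpaint hunks further down and assume the `retrieve_latents` helper from the sketch above):

import torch


def encode_with_optional_upcast(vae, image, generator=None):
    # Illustrative only: upcast when the VAE's config asks for it, then cast back
    # so the rest of an fp16 pipeline keeps running in half precision.
    dtype = image.dtype
    if vae.config.force_upcast:
        image = image.float()
        vae.to(torch.float32)
    latents = retrieve_latents(vae.encode(image), generator=generator)
    if vae.config.force_upcast:
        vae.to(dtype)
    return latents.to(dtype)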

View File

@@ -148,6 +148,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self.tile_sample_min_size = 512
         self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
 
+        self.register_to_config(block_out_channels=decoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
     def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value

View File

@@ -138,6 +138,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         )
         self.decoder_scheduler = ConsistencyDecoderScheduler()
         self.register_to_config(block_out_channels=encoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
         self.register_buffer(
             "means",
             torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],

View File

@@ -76,9 +76,13 @@ EXAMPLE_DOC_STRING = """
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -92,9 +92,13 @@ EXAMPLE_DOC_STRING = """
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -104,9 +104,13 @@ EXAMPLE_DOC_STRING = """
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -54,6 +54,20 @@ if is_invisible_watermark_available():
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -824,12 +838,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         if self.vae.config.force_upcast:
             self.vae.to(dtype)

View File

@@ -133,9 +133,13 @@ EXAMPLE_DOC_STRING = """
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -44,9 +44,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -35,9 +35,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -61,6 +61,20 @@ def preprocess(image):
     return image
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
     # 1. get previous step value (=t-1)
     prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
@@ -567,11 +581,12 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
         if isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
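
After encoding, the latents are multiplied by `vae.config.scaling_factor` so their magnitude matches what the UNet was trained on; decoding later divides by the same factor. A short reminder of that round trip (sketch with assumed `vae`/`image` names, reusing the `retrieve_latents` helper from the sketch at the top of this page):

def scaled_roundtrip(vae, image, generator=None):
    # Scale on encode, unscale on decode, mirroring the pipelines in this commit.
    latents = retrieve_latents(vae.encode(image), generator=generator)
    latents = vae.config.scaling_factor * latents  # e.g. 0.18215 for the SD KL VAE
    return vae.decode(latents / vae.config.scaling_factor).sample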

View File

@@ -37,9 +37,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -73,9 +73,13 @@ EXAMPLE_DOC_STRING = """
 """
 
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -160,9 +160,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -58,6 +58,20 @@ def preprocess(image):
     return image
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
@@ -320,7 +334,6 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             prompt_embeds.dtype,
             device,
             self.do_classifier_free_guidance,
-            generator,
         )
 
         height, width = image_latents.shape[-2:]
@@ -716,17 +729,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         if image.shape[1] == 4:
             image_latents = image
         else:
-            if isinstance(generator, list) and len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
-
-            if isinstance(generator, list):
-                image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
-                image_latents = torch.cat(image_latents, dim=0)
-            else:
-                image_latents = self.vae.encode(image).latent_dist.mode()
+            image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
 
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand image_latents for batch_size
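
For KL VAEs the behaviour here is unchanged: `sample_mode="argmax"` returns the distribution mode, which is exactly what the removed `.latent_dist.mode()` calls did, while the helper additionally handles encoder outputs that have no `latent_dist`. Reusing the toy objects from the sketch at the top of this page:

# Equivalent for a KL-style encoder output:
old_way = kl_style.latent_dist.mode()
new_way = retrieve_latents(kl_style, sample_mode="argmax")
assert torch.equal(old_way, new_way)

# The new form also covers Tiny-style outputs, which the old code could not:
tiny_latents = retrieve_latents(tiny_style, sample_mode="argmax")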

View File

@@ -105,9 +105,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -250,9 +250,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:

View File

@@ -88,6 +88,20 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -533,17 +547,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
             self.upcast_vae()
             image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
-            image_latents = torch.cat(image_latents, dim=0)
-        else:
-            image_latents = self.vae.encode(image).latent_dist.mode()
+        image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
 
         # cast back to fp16 if needed
         if needs_upcasting:
@@ -866,7 +870,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
             prompt_embeds.dtype,
             device,
             do_classifier_free_guidance,
-            generator,
         )
 
         # 7. Prepare latent variables

View File

@@ -79,6 +79,20 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
     # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
     # reshape to ncfhw
@@ -547,14 +561,14 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(video).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(video), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents

View File

@@ -46,6 +46,82 @@ from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
 enable_full_determinism()
 
 
+def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    init_dict = {
+        "block_out_channels": block_out_channels,
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+        "latent_channels": 4,
+        "norm_num_groups": norm_num_groups,
+    }
+    return init_dict
+
+
+def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    init_dict = {
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "down_block_out_channels": block_out_channels,
+        "layers_per_down_block": 1,
+        "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+        "up_block_out_channels": block_out_channels,
+        "layers_per_up_block": 1,
+        "act_fn": "silu",
+        "latent_channels": 4,
+        "norm_num_groups": norm_num_groups,
+        "sample_size": 32,
+        "scaling_factor": 0.18215,
+    }
+    return init_dict
+
+
+def get_autoencoder_tiny_config(block_out_channels=None):
+    block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
+    init_dict = {
+        "in_channels": 3,
+        "out_channels": 3,
+        "encoder_block_out_channels": block_out_channels,
+        "decoder_block_out_channels": block_out_channels,
+        "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
+        "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
+    }
+    return init_dict
+
+
+def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    return {
+        "encoder_block_out_channels": block_out_channels,
+        "encoder_in_channels": 3,
+        "encoder_out_channels": 4,
+        "encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "decoder_add_attention": False,
+        "decoder_block_out_channels": block_out_channels,
+        "decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
+        "decoder_downsample_padding": 1,
+        "decoder_in_channels": 7,
+        "decoder_layers_per_block": 1,
+        "decoder_norm_eps": 1e-05,
+        "decoder_norm_num_groups": norm_num_groups,
+        "encoder_norm_num_groups": norm_num_groups,
+        "decoder_num_train_timesteps": 1024,
+        "decoder_out_channels": 6,
+        "decoder_resnet_time_scale_shift": "scale_shift",
+        "decoder_time_embedding_type": "learned",
+        "decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
+        "scaling_factor": 1,
+        "latent_channels": 4,
+    }
+
+
 class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
     main_input_name = "sample"
@@ -70,14 +146,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         return (3, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "block_out_channels": [32, 64],
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "latent_channels": 4,
-        }
+        init_dict = get_autoencoder_kl_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -214,21 +283,7 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.T
         return (3, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "down_block_out_channels": [32, 64],
-            "layers_per_down_block": 1,
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "up_block_out_channels": [32, 64],
-            "layers_per_up_block": 1,
-            "act_fn": "silu",
-            "latent_channels": 4,
-            "norm_num_groups": 32,
-            "sample_size": 32,
-            "scaling_factor": 0.18215,
-        }
+        init_dict = get_asym_autoencoder_kl_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -263,14 +318,7 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
         return (3, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 3,
-            "out_channels": 3,
-            "encoder_block_out_channels": (32, 32),
-            "decoder_block_out_channels": (32, 32),
-            "num_encoder_blocks": (1, 2),
-            "num_decoder_blocks": (2, 1),
-        }
+        init_dict = get_autoencoder_tiny_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -302,33 +350,7 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
     @property
     def init_dict(self):
-        return {
-            "encoder_block_out_channels": [32, 64],
-            "encoder_in_channels": 3,
-            "encoder_out_channels": 4,
-            "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "decoder_add_attention": False,
-            "decoder_block_out_channels": [32, 64],
-            "decoder_down_block_types": [
-                "ResnetDownsampleBlock2D",
-                "ResnetDownsampleBlock2D",
-            ],
-            "decoder_downsample_padding": 1,
-            "decoder_in_channels": 7,
-            "decoder_layers_per_block": 1,
-            "decoder_norm_eps": 1e-05,
-            "decoder_norm_num_groups": 32,
-            "decoder_num_train_timesteps": 1024,
-            "decoder_out_channels": 6,
-            "decoder_resnet_time_scale_shift": "scale_shift",
-            "decoder_time_embedding_type": "learned",
-            "decoder_up_block_types": [
-                "ResnetUpsampleBlock2D",
-                "ResnetUpsampleBlock2D",
-            ],
-            "scaling_factor": 1,
-            "latent_channels": 4,
-        }
+        return get_consistency_vae_config()
 
     def prepare_init_args_and_inputs_for_common(self):
         return self.init_dict, self.inputs_dict()
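
The new `get_*_config` helpers make it easy to build small versions of all four VAEs outside the test classes too, which is what the new `test_multi_vae` in the pipeline test mixin (last diff below) relies on. A hedged usage sketch, assuming the test module is importable from a repository checkout:

from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny, ConsistencyDecoderVAE
from tests.models.test_models_vae import (
    get_asym_autoencoder_kl_config,
    get_autoencoder_kl_config,
    get_autoencoder_tiny_config,
    get_consistency_vae_config,
)

block_out_channels, norm_num_groups = [32, 64], 32
vaes = [
    AutoencoderKL(**get_autoencoder_kl_config(block_out_channels, norm_num_groups)),
    AsymmetricAutoencoderKL(**get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups)),
    ConsistencyDecoderVAE(**get_consistency_vae_config(block_out_channels, norm_num_groups)),
    AutoencoderTiny(**get_autoencoder_tiny_config(block_out_channels)),
]
for vae in vaes:
    # The three alternative VAEs register force_upcast=False (see the model diffs above).
    print(type(vae).__name__, vae.config.force_upcast)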

View File

@@ -17,7 +17,16 @@ from huggingface_hub import delete_repo
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 import diffusers
-from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AsymmetricAutoencoderKL,
+    AutoencoderKL,
+    AutoencoderTiny,
+    ConsistencyDecoderVAE,
+    DDIMScheduler,
+    DiffusionPipeline,
+    StableDiffusionPipeline,
+    UNet2DConditionModel,
+)
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
@@ -28,6 +37,12 @@ from diffusers.utils.testing_utils import (
     torch_device,
 )
 
+from ..models.test_models_vae import (
+    get_asym_autoencoder_kl_config,
+    get_autoencoder_kl_config,
+    get_autoencoder_tiny_config,
+    get_consistency_vae_config,
+)
 from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -171,6 +186,34 @@ class PipelineLatentTesterMixin:
         max_diff = np.abs(out - out_latents_inputs).max()
         self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
 
+    def test_multi_vae(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        block_out_channels = pipe.vae.config.block_out_channels
+        norm_num_groups = pipe.vae.config.norm_num_groups
+
+        vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
+        configs = [
+            get_autoencoder_kl_config(block_out_channels, norm_num_groups),
+            get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
+            get_consistency_vae_config(block_out_channels, norm_num_groups),
+            get_autoencoder_tiny_config(block_out_channels),
+        ]
+
+        out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+        for vae_cls, config in zip(vae_classes, configs):
+            vae = vae_cls(**config)
+            vae = vae.to(torch_device)
+            components["vae"] = vae
+            vae_pipe = self.pipeline_class(**components)
+            out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+            assert out_vae_np.shape == out_np.shape
+
 
 @require_torch
 class PipelineKarrasSchedulerTesterMixin: