Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
[Vae] Make sure all vae's work with latent diffusion models (#5880)
* add comments to explain the code better
* add comments to explain the code better
* add comments to explain the code better
* add comments to explain the code better
* add comments to explain the code better
* fix more
* fix more
* fix more
* fix more
* fix more
* fix more
Committed by GitHub · parent 20f0cbc88f · commit e550163b9f
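At the core of the PR is the fact that the different VAE classes return differently shaped `encode()` outputs, so pipeline code that reaches directly into `latent_dist` breaks for `AutoencoderTiny`. Below is a minimal sketch (not part of the diff, toy configs assumed) of the two output shapes the generalized `retrieve_latents` helper in the hunks further down has to cover:

```python
import torch
from diffusers import AutoencoderKL, AutoencoderTiny

image = torch.randn(1, 3, 64, 64)

# AutoencoderKL (and AsymmetricAutoencoderKL / ConsistencyDecoderVAE) return a posterior
# distribution: callers choose .sample(generator) or the deterministic .mode().
kl_out = AutoencoderKL(block_out_channels=[32], norm_num_groups=32).encode(image)
print(kl_out.latent_dist.sample().shape)

# AutoencoderTiny has no distribution; its encode() output only carries plain .latents.
tiny_out = AutoencoderTiny().encode(image)
print(tiny_out.latents.shape)
```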
@@ -108,6 +108,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
        self.use_slicing = False
        self.use_tiling = False

+        self.register_to_config(block_out_channels=up_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
@@ -148,6 +148,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
        self.tile_sample_min_size = 512
        self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor

+        self.register_to_config(block_out_channels=decoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value
@@ -138,6 +138,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
        )
        self.decoder_scheduler = ConsistencyDecoderScheduler()
        self.register_to_config(block_out_channels=encoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
        self.register_buffer(
            "means",
            torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
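All three models above now register `force_upcast=False`. A hedged sketch of the pipeline-side check this flag feeds (the same `self.vae.config.force_upcast` test appears in the SDXL ControlNet inpaint hunk below); the helper name here is invented purely for illustration:

```python
import torch

def maybe_upcast_vae(vae, image):
    # Only VAEs whose config keeps force_upcast=True (e.g. the fp16-unstable SDXL
    # AutoencoderKL) are moved to float32 before encoding; the three VAEs patched
    # above opt out by registering force_upcast=False.
    if getattr(vae.config, "force_upcast", False):
        image = image.float()
        vae.to(dtype=torch.float32)
    return vae, image
```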
@@ -76,9 +76,13 @@ EXAMPLE_DOC_STRING = """


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:

@@ -92,9 +92,13 @@ EXAMPLE_DOC_STRING = """


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:

@@ -104,9 +104,13 @@ EXAMPLE_DOC_STRING = """


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
@@ -54,6 +54,20 @@ if is_invisible_watermark_available():
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
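A short usage sketch of the extended helper (the same definition is repeated verbatim in each pipeline below; the `# Copied from` target is the canonical copy); the toy `AutoencoderKL` here is only for illustration:

```python
import torch
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents

vae = AutoencoderKL(block_out_channels=[32], norm_num_groups=32)
image = torch.randn(1, 3, 32, 32)

# "sample" (the default): draw a reproducible sample from the posterior
latents = retrieve_latents(vae.encode(image), generator=torch.Generator().manual_seed(0))

# "argmax": take the distribution mode -- deterministic, so no generator is required
# (this is what the InstructPix2Pix pipelines switch to further down)
latents_det = retrieve_latents(vae.encode(image), sample_mode="argmax")
print(latents.shape, latents_det.shape)
```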
@@ -824,12 +838,12 @@ class StableDiffusionXLControlNetInpaintPipeline(

        if isinstance(generator, list):
            image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        if self.vae.config.force_upcast:
            self.vae.to(dtype)
@@ -133,9 +133,13 @@ EXAMPLE_DOC_STRING = """


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:

@@ -44,9 +44,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:

@@ -35,9 +35,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
@@ -61,6 +61,20 @@ def preprocess(image):
    return image


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")


def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
    # 1. get previous step value (=t-1)
    prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
@@ -567,11 +581,12 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor

        if isinstance(generator, list):
            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        init_latents = self.vae.config.scaling_factor * init_latents
@@ -37,9 +37,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:

@@ -73,9 +73,13 @@ EXAMPLE_DOC_STRING = """
"""


-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:

@@ -160,9 +160,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
@@ -58,6 +58,20 @@ def preprocess(image):
    return image


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")


class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
@@ -320,7 +334,6 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
            prompt_embeds.dtype,
            device,
            self.do_classifier_free_guidance,
-            generator,
        )

        height, width = image_latents.shape[-2:]
@@ -716,17 +729,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
        if image.shape[1] == 4:
            image_latents = image
        else:
-            if isinstance(generator, list) and len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
-
-            if isinstance(generator, list):
-                image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
-                image_latents = torch.cat(image_latents, dim=0)
-            else:
-                image_latents = self.vae.encode(image).latent_dist.mode()
+            image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")

        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand image_latents for batch_size
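The dozen removed lines collapse into a single call because `latent_dist.mode()` ignores the generator entirely, so the per-generator validation and branching no longer buy anything. A quick equivalence check, using a toy VAE and assumed names rather than the pipeline's actual inputs:

```python
import torch
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents

vae = AutoencoderKL(block_out_channels=[32], norm_num_groups=32)
image = torch.randn(2, 3, 32, 32)

old_style = vae.encode(image).latent_dist.mode()
new_style = retrieve_latents(vae.encode(image), sample_mode="argmax")
assert torch.allclose(old_style, new_style)
```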
@@ -105,9 +105,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:

@@ -250,9 +250,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
@@ -88,6 +88,20 @@ EXAMPLE_DOC_STRING = """
"""


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -533,17 +547,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
            self.upcast_vae()
            image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
-            image_latents = torch.cat(image_latents, dim=0)
-        else:
-            image_latents = self.vae.encode(image).latent_dist.mode()
+        image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")

        # cast back to fp16 if needed
        if needs_upcasting:
@@ -866,7 +870,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
            prompt_embeds.dtype,
            device,
            do_classifier_free_guidance,
-            generator,
        )

        # 7. Prepare latent variables
@@ -79,6 +79,20 @@ EXAMPLE_DOC_STRING = """
"""


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")


def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
    # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
    # reshape to ncfhw
@@ -547,14 +561,14 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        elif isinstance(generator, list):
            init_latents = [
-                self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
-            init_latents = self.vae.encode(video).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(video), generator=generator)

        init_latents = self.vae.config.scaling_factor * init_latents
@@ -46,6 +46,82 @@ from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()


+def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    init_dict = {
+        "block_out_channels": block_out_channels,
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+        "latent_channels": 4,
+        "norm_num_groups": norm_num_groups,
+    }
+    return init_dict


+def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    init_dict = {
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "down_block_out_channels": block_out_channels,
+        "layers_per_down_block": 1,
+        "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+        "up_block_out_channels": block_out_channels,
+        "layers_per_up_block": 1,
+        "act_fn": "silu",
+        "latent_channels": 4,
+        "norm_num_groups": norm_num_groups,
+        "sample_size": 32,
+        "scaling_factor": 0.18215,
+    }
+    return init_dict


+def get_autoencoder_tiny_config(block_out_channels=None):
+    block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
+    init_dict = {
+        "in_channels": 3,
+        "out_channels": 3,
+        "encoder_block_out_channels": block_out_channels,
+        "decoder_block_out_channels": block_out_channels,
+        "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
+        "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
+    }
+    return init_dict


+def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    return {
+        "encoder_block_out_channels": block_out_channels,
+        "encoder_in_channels": 3,
+        "encoder_out_channels": 4,
+        "encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "decoder_add_attention": False,
+        "decoder_block_out_channels": block_out_channels,
+        "decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
+        "decoder_downsample_padding": 1,
+        "decoder_in_channels": 7,
+        "decoder_layers_per_block": 1,
+        "decoder_norm_eps": 1e-05,
+        "decoder_norm_num_groups": norm_num_groups,
+        "encoder_norm_num_groups": norm_num_groups,
+        "decoder_num_train_timesteps": 1024,
+        "decoder_out_channels": 6,
+        "decoder_resnet_time_scale_shift": "scale_shift",
+        "decoder_time_embedding_type": "learned",
+        "decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
+        "scaling_factor": 1,
+        "latent_channels": 4,
+    }


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AutoencoderKL
    main_input_name = "sample"
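A minimal sketch of how the shared helpers above are meant to be used, assuming they are imported from this test module: every config instantiates its VAE class directly, which is what both the model tests and the new pipeline-level `test_multi_vae` rely on.

```python
import torch
from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny, ConsistencyDecoderVAE

sample = torch.randn(1, 3, 32, 32)
vaes = [
    AutoencoderKL(**get_autoencoder_kl_config()),
    AsymmetricAutoencoderKL(**get_asym_autoencoder_kl_config()),
    AutoencoderTiny(**get_autoencoder_tiny_config()),
    ConsistencyDecoderVAE(**get_consistency_vae_config()),
]
for vae in vaes:
    # every class exposes encode(); only the shape of the returned object differs
    print(type(vae).__name__, type(vae.encode(sample)).__name__)
```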
@@ -70,14 +146,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "block_out_channels": [32, 64],
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "latent_channels": 4,
-        }
+        init_dict = get_autoencoder_kl_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
@@ -214,21 +283,7 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.T
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "down_block_out_channels": [32, 64],
-            "layers_per_down_block": 1,
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "up_block_out_channels": [32, 64],
-            "layers_per_up_block": 1,
-            "act_fn": "silu",
-            "latent_channels": 4,
-            "norm_num_groups": 32,
-            "sample_size": 32,
-            "scaling_factor": 0.18215,
-        }
+        init_dict = get_asym_autoencoder_kl_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
@@ -263,14 +318,7 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 3,
-            "out_channels": 3,
-            "encoder_block_out_channels": (32, 32),
-            "decoder_block_out_channels": (32, 32),
-            "num_encoder_blocks": (1, 2),
-            "num_decoder_blocks": (2, 1),
-        }
+        init_dict = get_autoencoder_tiny_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
@@ -302,33 +350,7 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):

    @property
    def init_dict(self):
-        return {
-            "encoder_block_out_channels": [32, 64],
-            "encoder_in_channels": 3,
-            "encoder_out_channels": 4,
-            "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "decoder_add_attention": False,
-            "decoder_block_out_channels": [32, 64],
-            "decoder_down_block_types": [
-                "ResnetDownsampleBlock2D",
-                "ResnetDownsampleBlock2D",
-            ],
-            "decoder_downsample_padding": 1,
-            "decoder_in_channels": 7,
-            "decoder_layers_per_block": 1,
-            "decoder_norm_eps": 1e-05,
-            "decoder_norm_num_groups": 32,
-            "decoder_num_train_timesteps": 1024,
-            "decoder_out_channels": 6,
-            "decoder_resnet_time_scale_shift": "scale_shift",
-            "decoder_time_embedding_type": "learned",
-            "decoder_up_block_types": [
-                "ResnetUpsampleBlock2D",
-                "ResnetUpsampleBlock2D",
-            ],
-            "scaling_factor": 1,
-            "latent_channels": 4,
-        }
+        return get_consistency_vae_config()

    def prepare_init_args_and_inputs_for_common(self):
        return self.init_dict, self.inputs_dict()
@@ -17,7 +17,16 @@ from huggingface_hub import delete_repo
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

import diffusers
-from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AsymmetricAutoencoderKL,
+    AutoencoderKL,
+    AutoencoderTiny,
+    ConsistencyDecoderVAE,
+    DDIMScheduler,
+    DiffusionPipeline,
+    StableDiffusionPipeline,
+    UNet2DConditionModel,
+)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
@@ -28,6 +37,12 @@ from diffusers.utils.testing_utils import (
    torch_device,
)

+from ..models.test_models_vae import (
+    get_asym_autoencoder_kl_config,
+    get_autoencoder_kl_config,
+    get_autoencoder_tiny_config,
+    get_consistency_vae_config,
+)
from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -171,6 +186,34 @@ class PipelineLatentTesterMixin:
        max_diff = np.abs(out - out_latents_inputs).max()
        self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")

+    def test_multi_vae(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        block_out_channels = pipe.vae.config.block_out_channels
+        norm_num_groups = pipe.vae.config.norm_num_groups
+
+        vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
+        configs = [
+            get_autoencoder_kl_config(block_out_channels, norm_num_groups),
+            get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
+            get_consistency_vae_config(block_out_channels, norm_num_groups),
+            get_autoencoder_tiny_config(block_out_channels),
+        ]
+
+        out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+        for vae_cls, config in zip(vae_classes, configs):
+            vae = vae_cls(**config)
+            vae = vae.to(torch_device)
+            components["vae"] = vae
+            vae_pipe = self.pipeline_class(**components)
+            out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+            assert out_vae_np.shape == out_np.shape


@require_torch
class PipelineKarrasSchedulerTesterMixin:
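Outside the test suite, the behaviour `test_multi_vae` locks in is what lets users swap a pipeline's VAE for a different autoencoder class. A typical sketch; the checkpoint names are the commonly used public ones, not part of this diff:

```python
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Swap the default AutoencoderKL for the Tiny AutoEncoder (TAESD) checkpoint.
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a watercolor painting of a lighthouse", num_inference_steps=25).images[0]
```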