diff --git a/src/diffusers/models/autoencoder_asym_kl.py b/src/diffusers/models/autoencoder_asym_kl.py
index 656683b43f..818e181fcd 100644
--- a/src/diffusers/models/autoencoder_asym_kl.py
+++ b/src/diffusers/models/autoencoder_asym_kl.py
@@ -108,6 +108,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         self.use_slicing = False
         self.use_tiling = False
 
+        self.register_to_config(block_out_channels=up_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
     @apply_forward_hook
     def encode(
         self, x: torch.FloatTensor, return_dict: bool = True
diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py
index d2d2f6f940..56ccf30e04 100644
--- a/src/diffusers/models/autoencoder_tiny.py
+++ b/src/diffusers/models/autoencoder_tiny.py
@@ -148,6 +148,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self.tile_sample_min_size = 512
         self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
 
+        self.register_to_config(block_out_channels=decoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
     def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value
diff --git a/src/diffusers/models/consistency_decoder_vae.py b/src/diffusers/models/consistency_decoder_vae.py
index a2d82e2565..34176a35e8 100644
--- a/src/diffusers/models/consistency_decoder_vae.py
+++ b/src/diffusers/models/consistency_decoder_vae.py
@@ -138,6 +138,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         )
         self.decoder_scheduler = ConsistencyDecoderScheduler()
         self.register_to_config(block_out_channels=encoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
         self.register_buffer(
             "means",
             torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
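A usage sketch, not part of the patch itself: the `register_to_config` additions above give every VAE class the two config keys that pipelines already read from `AutoencoderKL`, namely `block_out_channels` and `force_upcast`. The constructor arguments below mirror the dummy values used in `tests/models/test_models_vae.py`:

    from diffusers import AutoencoderTiny

    vae = AutoencoderTiny(
        in_channels=3,
        out_channels=3,
        encoder_block_out_channels=(32, 32),
        decoder_block_out_channels=(32, 32),
        num_encoder_blocks=(1, 2),
        num_decoder_blocks=(2, 1),
    )
    # Written into the config by the patched __init__: the decoder channels are
    # mirrored as `block_out_channels`, and `force_upcast` is False, so guards
    # like `if self.vae.config.force_upcast:` simply skip the float32 upcast.
    assert vae.config.force_upcast is False
    print(vae.config.block_out_channels)  # (32, 32)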
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 9b5eb1b4c6..4272fa1247 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -76,9 +76,13 @@ EXAMPLE_DOC_STRING = """
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 8945bd3d9c..fa489941c9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -92,9 +92,13 @@ EXAMPLE_DOC_STRING = """
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 9e2e428eaf..7bbc4889e7 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -104,9 +104,13 @@ EXAMPLE_DOC_STRING = """
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
""" # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 0e7bd6e722..ed29a93938 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -44,9 +44,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 38b90b10ad..0a20981beb 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -35,9 +35,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 10adefcff0..e5c2c78720 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -61,6 +61,20 @@ def preprocess(image): return image +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, 
"latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): # 1. get previous step value (=t-1) prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps @@ -567,11 +581,12 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor if isinstance(generator, list): init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 6a712692ac..e431fee7bd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -37,9 +37,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 1bec0807a2..e3a1a0ed36 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -73,9 +73,13 @@ EXAMPLE_DOC_STRING = """ """ -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 251dfb5676..3570eaa6fd 100644 --- 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 251dfb5676..3570eaa6fd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -160,9 +160,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 49da65bfbe..d922803858 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -58,6 +58,20 @@ def preprocess(image):
     return image
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
@@ -320,7 +334,6 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             prompt_embeds.dtype,
             device,
             self.do_classifier_free_guidance,
-            generator,
         )
 
         height, width = image_latents.shape[-2:]
@@ -716,17 +729,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         if image.shape[1] == 4:
             image_latents = image
         else:
-            if isinstance(generator, list) and len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
-
-            if isinstance(generator, list):
-                image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
-                image_latents = torch.cat(image_latents, dim=0)
-            else:
-                image_latents = self.vae.encode(image).latent_dist.mode()
+            image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
 
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand image_latents for batch_size
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index dc8b95bf99..436d816e5e 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -105,9 +105,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index e49ec0d607..f54b680dfd 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -250,9 +250,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index d639bee39a..b14c746f99 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -88,6 +88,20 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -533,17 +547,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
             self.upcast_vae()
             image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
-            image_latents = torch.cat(image_latents, dim=0)
-        else:
-            image_latents = self.vae.encode(image).latent_dist.mode()
+        image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
 
         # cast back to fp16 if needed
         if needs_upcasting:
@@ -866,7 +870,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
             prompt_embeds.dtype,
             device,
             do_classifier_free_guidance,
-            generator,
         )
 
         # 7. Prepare latent variables
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index dae7127c22..6779a7b820 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -79,6 +79,20 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
     # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
     # reshape to ncfhw
@@ -547,14 +561,14 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(video).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(video), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py
index 842a08c90b..83788b836a 100644
--- a/tests/models/test_models_vae.py
+++ b/tests/models/test_models_vae.py
@@ -46,6 +46,82 @@ from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
 enable_full_determinism()
 
 
+def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    init_dict = {
+        "block_out_channels": block_out_channels,
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+        "latent_channels": 4,
+        "norm_num_groups": norm_num_groups,
+    }
+    return init_dict
+
+
+def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    init_dict = {
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "down_block_out_channels": block_out_channels,
+        "layers_per_down_block": 1,
+        "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+        "up_block_out_channels": block_out_channels,
+        "layers_per_up_block": 1,
+        "act_fn": "silu",
+        "latent_channels": 4,
+        "norm_num_groups": norm_num_groups,
+        "sample_size": 32,
+        "scaling_factor": 0.18215,
+    }
+    return init_dict
+
+
+def get_autoencoder_tiny_config(block_out_channels=None):
+    block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
+    init_dict = {
+        "in_channels": 3,
+        "out_channels": 3,
+        "encoder_block_out_channels": block_out_channels,
+        "decoder_block_out_channels": block_out_channels,
+        "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
+        "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
+    }
+    return init_dict
+
+
+def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    return {
+        "encoder_block_out_channels": block_out_channels,
+        "encoder_in_channels": 3,
+        "encoder_out_channels": 4,
+        "encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "decoder_add_attention": False,
+        "decoder_block_out_channels": block_out_channels,
+        "decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
+        "decoder_downsample_padding": 1,
+        "decoder_in_channels": 7,
+        "decoder_layers_per_block": 1,
+        "decoder_norm_eps": 1e-05,
+        "decoder_norm_num_groups": norm_num_groups,
+        "encoder_norm_num_groups": norm_num_groups,
+        "decoder_num_train_timesteps": 1024,
+        "decoder_out_channels": 6,
+        "decoder_resnet_time_scale_shift": "scale_shift",
+        "decoder_time_embedding_type": "learned",
+        "decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
+        "scaling_factor": 1,
+        "latent_channels": 4,
+    }
+
+
 class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
     main_input_name = "sample"
@@ -70,14 +146,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         return (3, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "block_out_channels": [32, 64],
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "latent_channels": 4,
-        }
+        init_dict = get_autoencoder_kl_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
@@ -214,21 +283,7 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.T
         return (3, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "down_block_out_channels": [32, 64],
-            "layers_per_down_block": 1,
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "up_block_out_channels": [32, 64],
-            "layers_per_up_block": 1,
-            "act_fn": "silu",
-            "latent_channels": 4,
-            "norm_num_groups": 32,
-            "sample_size": 32,
-            "scaling_factor": 0.18215,
-        }
+        init_dict = get_asym_autoencoder_kl_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
@@ -263,14 +318,7 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
         return (3, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 3,
-            "out_channels": 3,
-            "encoder_block_out_channels": (32, 32),
-            "decoder_block_out_channels": (32, 32),
-            "num_encoder_blocks": (1, 2),
-            "num_decoder_blocks": (2, 1),
-        }
+        init_dict = get_autoencoder_tiny_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
@@ -302,33 +350,7 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
 
     @property
     def init_dict(self):
-        return {
-            "encoder_block_out_channels": [32, 64],
-            "encoder_in_channels": 3,
-            "encoder_out_channels": 4,
-            "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "decoder_add_attention": False,
-            "decoder_block_out_channels": [32, 64],
-            "decoder_down_block_types": [
-                "ResnetDownsampleBlock2D",
-                "ResnetDownsampleBlock2D",
-            ],
-            "decoder_downsample_padding": 1,
-            "decoder_in_channels": 7,
-            "decoder_layers_per_block": 1,
-            "decoder_norm_eps": 1e-05,
-            "decoder_norm_num_groups": 32,
-            "decoder_num_train_timesteps": 1024,
-            "decoder_out_channels": 6,
-            "decoder_resnet_time_scale_shift": "scale_shift",
-            "decoder_time_embedding_type": "learned",
-            "decoder_up_block_types": [
-                "ResnetUpsampleBlock2D",
-                "ResnetUpsampleBlock2D",
-            ],
-            "scaling_factor": 1,
-            "latent_channels": 4,
-        }
+        return get_consistency_vae_config()
 
     def prepare_init_args_and_inputs_for_common(self):
         return self.init_dict, self.inputs_dict()
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index dfe523cda9..e111759211 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -17,7 +17,16 @@ from huggingface_hub import delete_repo
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 import diffusers
-from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AsymmetricAutoencoderKL,
+    AutoencoderKL,
+    AutoencoderTiny,
+    ConsistencyDecoderVAE,
+    DDIMScheduler,
+    DiffusionPipeline,
+    StableDiffusionPipeline,
+    UNet2DConditionModel,
+)
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
@@ -28,6 +37,12 @@ from diffusers.utils.testing_utils import (
     torch_device,
 )
 
+from ..models.test_models_vae import (
+    get_asym_autoencoder_kl_config,
+    get_autoencoder_kl_config,
+    get_autoencoder_tiny_config,
+    get_consistency_vae_config,
+)
 from ..others.test_utils import TOKEN, USER, is_staging_test
 
 
@@ -171,6 +186,34 @@ class PipelineLatentTesterMixin:
         max_diff = np.abs(out - out_latents_inputs).max()
         self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
 
+    def test_multi_vae(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        block_out_channels = pipe.vae.config.block_out_channels
+        norm_num_groups = pipe.vae.config.norm_num_groups
+
+        vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
+        configs = [
+            get_autoencoder_kl_config(block_out_channels, norm_num_groups),
+            get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
+            get_consistency_vae_config(block_out_channels, norm_num_groups),
+            get_autoencoder_tiny_config(block_out_channels),
+        ]
+
+        out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+        for vae_cls, config in zip(vae_classes, configs):
+            vae = vae_cls(**config)
+            vae = vae.to(torch_device)
+            components["vae"] = vae
+            vae_pipe = self.pipeline_class(**components)
+            out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+            assert out_vae_np.shape == out_np.shape
+
 
 @require_torch
 class PipelineKarrasSchedulerTesterMixin:
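A final sketch of the design behind `test_multi_vae`, assuming a diffusers source checkout so the new test helpers are importable: each helper re-derives its config from the dummy pipeline's existing VAE, so every swapped-in VAE keeps the same number of up/down blocks, and therefore the same spatial scale factor, which is what lets the test assert matching output shapes:

    from diffusers import AutoencoderKL

    from tests.models.test_models_vae import (
        get_asym_autoencoder_kl_config,
        get_autoencoder_kl_config,
    )

    dummy_vae = AutoencoderKL(**get_autoencoder_kl_config())  # stands in for pipe.vae
    block_out_channels = dummy_vae.config.block_out_channels
    norm_num_groups = dummy_vae.config.norm_num_groups

    asym_config = get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups)
    # same block count as the original VAE, hence the same downsampling factor,
    # so the pipeline's output resolution is unchanged after the swap
    assert len(asym_config["up_block_out_channels"]) == len(block_out_channels)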