From 7186bb45f00adb36a880bd30d41cfddb12faae11 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 19:31:02 +0000 Subject: [PATCH] Add enable_vae_tiling to AllegroPipeline, fix example (#10212) --- .../pipelines/allegro/pipeline_allegro.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 2be596cf8e..b3650dc6ce 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -59,6 +59,7 @@ EXAMPLE_DOC_STRING = """ >>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32) >>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda") + >>> pipe.enable_vae_tiling() >>> prompt = ( ... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, " @@ -636,6 +637,35 @@ class AllegroPipeline(DiffusionPipeline): return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + @property def guidance_scale(self): return self._guidance_scale