From bf890bca0e8aed875d6a207f9b826ce894901522 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Sep 2024 09:07:13 +0200 Subject: [PATCH] fake context parallel cache, vae encode tiling --- .../autoencoders/autoencoder_kl_cogvideox.py | 105 +++++++++++++++++- .../pipeline_cogvideox_video2video.py | 26 ++--- 2 files changed, 112 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 17fa2bbf40..fe887b7db0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -999,6 +999,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different # number of temporal frames. self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 # We make the minimum height and width of sample for tiling half that of the generally supported self.tile_sample_min_height = sample_height // 2 @@ -1081,6 +1082,29 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): """ self.use_slicing = False + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + frame_batch_size = self.num_sample_frames_batch_size + enc = [] + for i in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + x_intermediate = x[:, :, start_frame:end_frame] + x_intermediate = self.encoder(x_intermediate) + if self.quant_conv is not None: + x_intermediate = self.quant_conv(x_intermediate) + enc.append(x_intermediate) + + self._clear_fake_context_parallel_cache() + enc = torch.cat(enc, dim=2) + + return enc + @apply_forward_hook def encode( self, x: torch.Tensor, return_dict: bool = True @@ -1094,13 +1118,17 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Returns: - The latent representations of the encoded images. If `return_dict` is True, a + The latent representations of the encoded videos. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - h = self.encoder(x) - if self.quant_conv is not None: - h = self.quant_conv(h) + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + if not return_dict: return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) @@ -1172,6 +1200,75 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): ) return b + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. 
The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + # For a rough memory estimate, take a look at the `tiled_decode` method. + batch_size, num_channels, num_frames, height, width = x.shape + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_latent_min_height - blend_extent_height + row_limit_width = self.tile_latent_min_width - blend_extent_width + frame_batch_size = self.num_sample_frames_batch_size + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + time = [] + for k in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = x[ + :, + :, + start_frame:end_frame, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + if self.quant_conv is not None: + tile = self.quant_conv(tile) + time.append(tile) + self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3) + return enc + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. 
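Usage sketch (editorial note, not part of the patch): with the VAE changes above, `AutoencoderKLCogVideoX.encode` now routes through `_encode`, which pushes `num_sample_frames_batch_size` frames at a time through the encoder and clears the fake context parallel cache afterwards; inputs larger than the tile minimums additionally go through `tiled_encode`. The checkpoint id, dtype, and printed shape below are illustrative assumptions rather than values taken from this diff, and `enable_tiling`/`enable_slicing` are assumed to be the existing toggles on this VAE class.

    import torch
    from diffusers import AutoencoderKLCogVideoX

    # Hypothetical checkpoint choice; any CogVideoX VAE weights should exercise the same code path.
    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
    ).to("cuda")
    vae.enable_slicing()  # when batch > 1, encode() splits the batch and calls _encode per sample
    vae.enable_tiling()   # inputs larger than tile_sample_min_* go through tiled_encode above

    # [batch, channels, frames, height, width]; 49 frames -> (49 - 1) // 4 + 1 = 13 latent frames
    video = torch.randn(1, 3, 49, 480, 720, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        latents = vae.encode(video).latent_dist.sample()
    print(latents.shape)  # expected: torch.Size([1, 16, 13, 60, 90])

Because 480x720 exceeds the default tile_sample_min_height/width (half the configured sample size), this input takes the tiled path; smaller inputs fall back to the frame-batched `_encode` loop.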
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index bc96a4ef12..fe31dd43b7 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -341,7 +341,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline): video: Optional[torch.Tensor] = None, batch_size: int = 1, num_channels_latents: int = 16, - num_frames: int = 13, height: int = 60, width: int = 90, dtype: Optional[torch.dtype] = None, @@ -350,13 +349,16 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline): latents: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None, ): + num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + shape = ( batch_size, - (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_frames, num_channels_latents, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -432,6 +434,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline): strength, negative_prompt, callback_on_step_end_tensor_inputs, + video=None, + latents=None, prompt_embeds=None, negative_prompt_embeds=None, ): @@ -479,6 +483,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline): f" {negative_prompt_embeds.shape}." ) + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + def fuse_qkv_projections(self) -> None: r"""Enables fused QKV projections.""" self.fusing_transformer = True @@ -539,7 +546,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 480, width: int = 720, - num_frames: int = 49, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, strength: float = 0.8, @@ -576,11 +582,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_frames (`int`, defaults to `48`): - Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will - contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where - num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that - needs to be satisfied is that of divisibility mentioned above. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -639,11 +640,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline): `tuple`. When returning a tuple, the first element is a list with the generated images. """ - if num_frames > 49: - raise ValueError( - "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." 
- ) - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -700,16 +696,16 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline): latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) self._num_timesteps = len(timesteps) - # 5. Prepare latents. + # 5. Prepare latents if latents is None: video = self.video_processor.preprocess_video(video, height=height, width=width) video = video.to(device=device, dtype=prompt_embeds.dtype) + latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( video, batch_size * num_videos_per_prompt, latent_channels, - num_frames, height, width, prompt_embeds.dtype,
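End-to-end sketch for the updated video-to-video pipeline (editorial note, not part of the patch): `num_frames` is no longer a call argument, since `prepare_latents` now derives the latent frame count from the input `video` (or from `latents`), and the hard `num_frames > 49` check has been removed. The checkpoint id, file names, prompt, seed, and fps below are illustrative assumptions, and `load_video` is assumed to be available in `diffusers.utils`.

    import torch
    from diffusers import CogVideoXVideoToVideoPipeline
    from diffusers.utils import export_to_video, load_video

    pipe = CogVideoXVideoToVideoPipeline.from_pretrained(
        "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.vae.enable_tiling()  # use the tiled VAE encode/decode paths added in this patch

    # Any clip works as long as (frame_count - 1) is divisible by vae_scale_factor_temporal;
    # the pipeline now infers num_frames from the input instead of taking it as a parameter.
    input_video = load_video("input.mp4")[:49]

    output = pipe(
        video=input_video,
        prompt="a panda dancing in the snow, cinematic lighting",
        strength=0.8,
        guidance_scale=6.0,
        num_inference_steps=50,
        generator=torch.Generator(device="cuda").manual_seed(42),
    ).frames[0]
    export_to_video(output, "output.mp4", fps=8)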