mirror of https://github.com/huggingface/diffusers.git

fake context parallel cache, vae encode tiling
@@ -999,6 +999,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
        # number of temporal frames.
        self.num_latent_frames_batch_size = 2
        self.num_sample_frames_batch_size = 8

        # We make the minimum height and width of sample for tiling half that of the generally supported
        self.tile_sample_min_height = sample_height // 2
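For a sense of scale, a small worked sketch of what these constants imply, assuming the usual CogVideoX sample size of 480x720 and a temporal compression factor of 4 (both values are assumptions here, not read from this hunk):

# Illustrative arithmetic only; sample size and compression factor are assumed.
sample_height, sample_width = 480, 720
temporal_compression_ratio = 4

num_latent_frames_batch_size = 2   # decode is batched over 2 latent frames at a time
num_sample_frames_batch_size = 8   # encode is batched over 8 pixel frames at a time
# 8 pixel frames map to roughly 8 / 4 = 2 latent frames, so the two batch sizes line up.

tile_sample_min_height = sample_height // 2   # 240
tile_sample_min_width = sample_width // 2     # 360
print(tile_sample_min_height, tile_sample_min_width)  # 240 360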
@@ -1081,6 +1082,29 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        """
        self.use_slicing = False

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = x.shape

        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)

        frame_batch_size = self.num_sample_frames_batch_size
        enc = []
        for i in range(num_frames // frame_batch_size):
            remaining_frames = num_frames % frame_batch_size
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            x_intermediate = x[:, :, start_frame:end_frame]
            x_intermediate = self.encoder(x_intermediate)
            if self.quant_conv is not None:
                x_intermediate = self.quant_conv(x_intermediate)
            enc.append(x_intermediate)

        self._clear_fake_context_parallel_cache()
        enc = torch.cat(enc, dim=2)

        return enc

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
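The loop above walks the time axis in fixed-size chunks, letting the first chunk absorb any remainder so that later chunks stay aligned, and the fake context-parallel cache is cleared once the whole clip is encoded so state from one call does not carry into the next. A standalone sketch of the same indexing, with an assumed 49-frame input:

# Standalone sketch of the frame-batch indexing used in `_encode` above (values assumed).
num_frames = 49        # e.g. a 49-frame input clip
frame_batch_size = 8   # num_sample_frames_batch_size

chunks = []
for i in range(num_frames // frame_batch_size):
    remaining_frames = num_frames % frame_batch_size
    # Only the first chunk is widened by the remainder; every later chunk is exactly frame_batch_size long.
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    chunks.append((start_frame, end_frame))

print(chunks)  # [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]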
@@ -1094,13 +1118,17 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        h = self.encoder(x)
        if self.quant_conv is not None:
            h = self.quant_conv(h)
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)
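A minimal usage sketch of the batched encode path; the checkpoint ID, tensor shape, and the enable_slicing/enable_tiling toggles are assumptions for illustration rather than something this hunk adds:

import torch
from diffusers import AutoencoderKLCogVideoX

# Any CogVideoX VAE checkpoint would do; this ID is just an example.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_slicing()  # encode/decode one sample of the batch at a time
vae.enable_tiling()   # spatially tile large frames

# (batch, channels, frames, height, width); shape chosen for illustration.
video = torch.randn(2, 3, 49, 480, 720, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
print(latents.shape)  # roughly (2, 16, 13, 60, 90) with 8x spatial / 4x temporal compression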
@@ -1172,6 +1200,75 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
            )
        return b

    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        # For a rough memory estimate, take a look at the `tiled_decode` method.
        batch_size, num_channels, num_frames, height, width = x.shape

        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
        row_limit_height = self.tile_latent_min_height - blend_extent_height
        row_limit_width = self.tile_latent_min_width - blend_extent_width
        frame_batch_size = self.num_sample_frames_batch_size

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, overlap_height):
            row = []
            for j in range(0, width, overlap_width):
                time = []
                for k in range(num_frames // frame_batch_size):
                    remaining_frames = num_frames % frame_batch_size
                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                    end_frame = frame_batch_size * (k + 1) + remaining_frames
                    tile = x[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_sample_min_height,
                        j : j + self.tile_sample_min_width,
                    ]
                    tile = self.encoder(tile)
                    if self.quant_conv is not None:
                        tile = self.quant_conv(tile)
                    time.append(tile)
                self._clear_fake_context_parallel_cache()
                row.append(torch.cat(time, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
            result_rows.append(torch.cat(result_row, dim=4))

        enc = torch.cat(result_rows, dim=3)
        return enc

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.
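To make the tiling geometry concrete, here is the overlap/blend arithmetic with plugged-in numbers; all concrete values (tile minimums, 8x spatial compression, overlap factors of 1/6 and 1/5) are assumed defaults for illustration, not values stated in this diff:

# Illustrative tiling arithmetic; every concrete value below is an assumption.
tile_sample_min_height, tile_sample_min_width = 240, 360
tile_latent_min_height, tile_latent_min_width = 240 // 8, 360 // 8   # 30, 45
tile_overlap_factor_height, tile_overlap_factor_width = 1 / 6, 1 / 5

overlap_height = int(tile_sample_min_height * (1 - tile_overlap_factor_height))  # 200: stride between tile rows
overlap_width = int(tile_sample_min_width * (1 - tile_overlap_factor_width))     # 288: stride between tile cols
blend_extent_height = int(tile_latent_min_height * tile_overlap_factor_height)   # 5 latent rows blended
blend_extent_width = int(tile_latent_min_width * tile_overlap_factor_width)      # 9 latent cols blended
row_limit_height = tile_latent_min_height - blend_extent_height                  # 25 latent rows kept per tile
row_limit_width = tile_latent_min_width - blend_extent_width                     # 36 latent cols kept per tile

# A 480x720 input is covered by tiles starting at rows 0, 200, 400 and cols 0, 288, 576;
# each encoded tile is blended with its top/left neighbours over the blend extents,
# then cropped to row_limit_* before rows are concatenated back together.
print(overlap_height, overlap_width, blend_extent_height, blend_extent_width)  # 200 288 5 9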
@@ -341,7 +341,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
        video: Optional[torch.Tensor] = None,
        batch_size: int = 1,
        num_channels_latents: int = 16,
        num_frames: int = 13,
        height: int = 60,
        width: int = 90,
        dtype: Optional[torch.dtype] = None,
@@ -350,13 +349,16 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
        latents: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor] = None,
    ):
        num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)

        shape = (
            batch_size,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
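The num_frames computed at the top of prepare_latents is already a latent-space frame count; a worked example of that compression, assuming vae_scale_factor_temporal=4, vae_scale_factor_spatial=8, and a 49-frame, 480x720 input:

# Worked example of the latent shape arithmetic above (all values assumed for illustration).
vae_scale_factor_temporal = 4
pixel_frames = 49                       # video.size(2) for an assumed 49-frame input

latent_frames = (pixel_frames - 1) // vae_scale_factor_temporal + 1
print(latent_frames)                    # 13 -- cf. the num_frames: int = 13 default above

# Spatial dimensions are compressed the same way with vae_scale_factor_spatial:
height, width, vae_scale_factor_spatial = 480, 720, 8
print(height // vae_scale_factor_spatial, width // vae_scale_factor_spatial)  # 60 90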
@@ -432,6 +434,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
        strength,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        video=None,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
@@ -479,6 +483,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
                f" {negative_prompt_embeds.shape}."
            )

        if video is not None and latents is not None:
            raise ValueError("Only one of `video` or `latents` should be provided")

    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        self.fusing_transformer = True
@@ -539,7 +546,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        strength: float = 0.8,
@@ -576,11 +582,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_frames (`int`, defaults to `48`):
                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
                needs to be satisfied is that of divisibility mentioned above.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
@@ -639,11 +640,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
                `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        if num_frames > 49:
            raise ValueError(
                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -700,16 +696,16 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents.
        # 5. Prepare latents
        if latents is None:
            video = self.video_processor.preprocess_video(video, height=height, width=width)
            video = video.to(device=device, dtype=prompt_embeds.dtype)

        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            video,
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
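Putting the pipeline changes together, a hedged end-to-end sketch of video-to-video generation with the VAE memory savers enabled; the checkpoint ID, input clip, and generation settings are illustrative assumptions:

import torch
from diffusers import CogVideoXVideoToVideoPipeline
from diffusers.utils import export_to_video, load_video

# Example checkpoint; any CogVideoX checkpoint with a compatible VAE should work.
pipe = CogVideoXVideoToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-2b", torch_dtype=torch.float16
).to("cuda")
pipe.vae.enable_tiling()   # tiled encode/decode from the VAE changes above
pipe.vae.enable_slicing()  # per-sample encode/decode across the batch

# Frame count is taken from the input clip; keep it within the 49-frame limit mentioned above.
input_video = load_video("input.mp4")[:49]

output = pipe(
    prompt="a watercolor painting of the same scene",
    video=input_video,
    strength=0.8,
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]
export_to_video(output, "output.mp4", fps=8)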