Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Implement framewise encoding/decoding in LTX Video VAE (#10488)
* add framewise decode
* add framewise encode, refactor tiled encode/decode
* add sanity test tiling for ltx
* run make style
* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

---------

Co-authored-by: Pham Hong Vinh <vinhph3@vng.com.vn>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
@@ -1010,10 +1010,12 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 512
         self.tile_sample_min_width = 512
+        self.tile_sample_min_num_frames = 16
 
         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
+        self.tile_sample_stride_num_frames = 8
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
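With these defaults, consecutive temporal tiles cover 16 sample frames but start every 8 frames, so neighbouring tiles share an 8-frame overlap that is later cross-faded away. A minimal sketch of that arithmetic, assuming LTX's default compression ratios (32x spatial, 8x temporal), which are not shown in this hunk:

```python
# Sketch only: the tile sizes mirror the defaults above; the temporal
# compression ratio of 8 is an assumed LTX default, not part of this diff.
tile_sample_min_num_frames = 16    # frames covered by one temporal tile
tile_sample_stride_num_frames = 8  # frames between consecutive tile starts

overlap_sample_frames = tile_sample_min_num_frames - tile_sample_stride_num_frames
temporal_compression_ratio = 8     # assumed
overlap_latent_frames = overlap_sample_frames // temporal_compression_ratio

print(overlap_sample_frames)  # 8: sample frames shared by neighbouring tiles
print(overlap_latent_frames)  # 1: latent frame blended by blend_t after encoding
```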
@@ -1023,8 +1025,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self,
         tile_sample_min_height: Optional[int] = None,
         tile_sample_min_width: Optional[int] = None,
+        tile_sample_min_num_frames: Optional[int] = None,
         tile_sample_stride_height: Optional[float] = None,
         tile_sample_stride_width: Optional[float] = None,
+        tile_sample_stride_num_frames: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1046,8 +1050,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
 
     def disable_tiling(self) -> None:
         r"""
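With these setters in place, the temporal tile size and stride can be tuned from user code. A hypothetical usage sketch; the checkpoint id and dtype are illustrative, not part of this commit:

```python
import torch

from diffusers import AutoencoderKLLTXVideo

# Illustrative checkpoint; substitute whichever LTX-Video weights you use.
vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.bfloat16
)

# The defaults are fine for most clips; the new frame arguments let you trade
# memory for fewer temporal seams by enlarging the tile or shrinking the stride.
vae.enable_tiling(
    tile_sample_min_num_frames=16,
    tile_sample_stride_num_frames=8,
)
```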
@@ -1073,18 +1079,13 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape
 
+        if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
+            return self._temporal_tiled_encode(x)
+
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)
 
-        if self.use_framewise_encoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            enc = self.encoder(x)
+        enc = self.encoder(x)
 
         return enc
 
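Note the dispatch order: the temporal check runs before the spatial one, and `_temporal_tiled_encode` falls back to `tiled_encode` per tile when the input is also spatially large. A toy sketch of the routing condition (the helper name is made up for illustration):

```python
# Hypothetical helper mirroring the first condition in _encode above,
# with the default tile_sample_min_num_frames of 16.
def takes_temporal_path(num_frames: int, tile_sample_min_num_frames: int = 16) -> bool:
    # use_framewise_decoding is assumed enabled here.
    return num_frames > tile_sample_min_num_frames

print(takes_temporal_path(9))   # False: a 9-frame clip is encoded in one pass
print(takes_temporal_path(25))  # True: a 25-frame clip is split along time
```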
@@ -1121,19 +1122,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
 
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
+
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, temb, return_dict=return_dict)
 
-        if self.use_framewise_decoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            dec = self.decoder(z, temb)
+        dec = self.decoder(z, temb)
 
         if not return_dict:
             return (dec,)
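Decoding applies the same routing in latent space, so the sample-space thresholds are first divided by the compression ratios. A worked example, again assuming LTX's default ratios of 32x spatial and 8x temporal:

```python
# Worked example of the latent-space thresholds; the compression ratios are
# assumed LTX defaults, not shown in this hunk.
spatial_compression_ratio = 32
temporal_compression_ratio = 8

tile_latent_min_height = 512 // spatial_compression_ratio      # 16 latent rows
tile_latent_min_num_frames = 16 // temporal_compression_ratio  # 2 latent frames

# A latent with more than 2 frames is routed to _temporal_tiled_decode; one
# taller than 16 latent rows (512 pixels) is routed to tiled_decode.
print(tile_latent_min_height, tile_latent_min_num_frames)
```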
@@ -1189,6 +1186,14 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             )
         return b
 
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
     def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
 
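`blend_t` is a linear cross-fade along the frame axis (dim 2): overlap frame `x` takes weight `x / blend_extent` from the current tile and `1 - x / blend_extent` from the tail of the previous one. A standalone demonstration of the same ramp on dummy tensors:

```python
import torch

# Demonstrates the blend_t cross-fade on constant tensors: a is the tail of the
# previous tile (all zeros), b the head of the current tile (all ones).
blend_extent = 4
a = torch.zeros(1, 1, blend_extent, 1, 1)
b = torch.ones(1, 1, blend_extent, 1, 1)

for x in range(blend_extent):
    b[:, :, x] = a[:, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, x] * (x / blend_extent)

print(b.flatten().tolist())  # [0.0, 0.25, 0.5, 0.75]: a smooth ramp from a into b
```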
@@ -1217,17 +1222,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                if self.use_framewise_encoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.encoder(
-                        x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
-                    )
+                time = self.encoder(
+                    x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                )
 
                 row.append(time)
             rows.append(row)
@@ -1283,17 +1280,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, tile_latent_stride_height):
             row = []
             for j in range(0, width, tile_latent_stride_width):
-                if self.use_framewise_decoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.decoder(
-                        z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
-                    )
+                time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
 
                 row.append(time)
             rows.append(row)
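Spatially, tiles advance by the stride while covering the full tile extent, and the leftover overlap is smoothed by the spatial counterparts of `blend_t` (`blend_v`/`blend_h` in this file). A small arithmetic sketch of how many tiles the loops above produce, assuming 32x spatial compression:

```python
import math

# Tile count for one spatial dimension, mirroring range(0, extent, stride) in
# tiled_decode; the 32x spatial compression ratio is an assumed LTX default.
tile_latent_min = 512 // 32     # 16 latent pixels per tile
tile_latent_stride = 448 // 32  # 14 latent pixels between tile starts

def num_tiles(extent: int) -> int:
    return math.ceil(extent / tile_latent_stride)

print(num_tiles(32))  # 3 tiles of 16 latents each, overlapping by 2
```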
@@ -1318,6 +1305,74 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         return DecoderOutput(sample=dec)
 
+    def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+            tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+            if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+                tile = self.tiled_encode(tile)
+            else:
+                tile = self.encoder(tile)
+            if i > 0:
+                tile = tile[:, :, 1:, :, :]
+            row.append(tile)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+            else:
+                result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+        enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+        return enc
+
+    def _temporal_tiled_decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, tile_latent_stride_num_frames):
+            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
+            else:
+                decoded = self.decoder(tile, temb)
+            if i > 0:
+                decoded = decoded[:, :, :-1, :, :]
+            row.append(decoded)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+                result_row.append(tile)
+            else:
+                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
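The `+ 1` in the tile slices reflects the causal frame layout: the first sample frame maps to a latent frame of its own, and each subsequent group of `temporal_compression_ratio` frames maps to one more, which is exactly what the two round-trip formulas above encode. A quick check of that arithmetic, assuming the LTX default temporal ratio of 8:

```python
# Round-trip frame arithmetic used by the temporal tiling above; the temporal
# compression ratio of 8 is an assumed LTX default.
r = 8
num_frames = 49  # a typical "8k + 1" clip length

latent_num_frames = (num_frames - 1) // r + 1
num_sample_frames = (latent_num_frames - 1) * r + 1

print(latent_num_frames)   # 7 latent frames
print(num_sample_frames)   # 49: the original clip length is recovered exactly
```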
@@ -1334,5 +1389,5 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z = posterior.mode()
         dec = self.decode(z, temb)
         if not return_dict:
-            return (dec,)
+            return (dec.sample,)
         return dec
@@ -167,3 +167,34 @@ class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.
     @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
     def test_forward_with_norm_groups(self):
         pass
+
+    def test_enable_disable_tiling(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict).to(torch_device)
+
+        inputs_dict.update({"return_dict": False})
+
+        torch.manual_seed(0)
+        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        torch.manual_seed(0)
+        model.enable_tiling()
+        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertLess(
+            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+            0.5,
+            "VAE tiling should not affect the inference results",
+        )
+
+        torch.manual_seed(0)
+        model.disable_tiling()
+        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertEqual(
+            output_without_tiling.detach().cpu().numpy().all(),
+            output_without_tiling_2.detach().cpu().numpy().all(),
+            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+        )
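The sanity test checks that tiled and untiled outputs agree within a loose tolerance rather than exactly, since blending perturbs the overlap regions. To run just this test locally, something like the following should work; the file path is inferred from the test class name, not shown in this diff:

```
pytest tests/models/autoencoders/test_models_autoencoder_ltx_video.py -k test_enable_disable_tiling
```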