LTX 0.9.5 (#10968)
* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
  - all
  - __call__

## LTXConditionPipeline

[[autodoc]] LTXConditionPipeline
  - all
  - __call__

## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
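For orientation, a minimal text-to-video call through the new conditional pipeline might look like the sketch below. This is not taken from the diff; the checkpoint id and the generation settings are illustrative assumptions.

import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# Hypothetical usage sketch: the checkpoint id and settings are assumptions, not part of this PR.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A ship sailing through a stormy sea at night",
    negative_prompt="worst quality, blurry, jittery",
    width=768,
    height=512,
    num_frames=97,
    num_inference_steps=30,
).frames[0]
export_to_video(video, "ship.mp4", fps=24)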
@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
    "last_scale_shift_table": "scale_shift_table",
}

VAE_095_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.upsamplers.0",
    "up_blocks.8": "up_blocks.3",
    # encoder
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.1",
    "down_blocks.3": "down_blocks.1.downsamplers.0",
    "down_blocks.4": "down_blocks.2",
    "down_blocks.5": "down_blocks.2.downsamplers.0",
    "down_blocks.6": "down_blocks.3",
    "down_blocks.7": "down_blocks.3.downsamplers.0",
    "down_blocks.8": "mid_block",
    # common
    "last_time_embedder": "time_embedder",
    "last_scale_shift_table": "scale_shift_table",
}

VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_,
    "per_channel_statistics.mean-of-means": remove_keys_,
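Rename dicts like the one above are applied by substring replacement over every checkpoint key before the tensor is moved to its new name (the conversion loop shown further down iterates keys and rename pairs). A rough sketch of the idea, using a made-up key purely for illustration:

# Illustrative only: mimics how a rename dict is typically applied to checkpoint keys.
def rename_key(key: str, rename_dict: dict) -> str:
    for old, new in rename_dict.items():
        key = key.replace(old, new)
    return key

example_key = "decoder.up_blocks.3.res_blocks.0.conv1.conv.weight"  # hypothetical key
print(rename_key(example_key, {"up_blocks.3": "up_blocks.1.upsamplers.0"}))
# decoder.up_blocks.1.upsamplers.0.res_blocks.0.conv1.conv.weight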
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
    "model.diffusion_model": remove_keys_,
}

-VAE_091_SPECIAL_KEYS_REMAP = {
-    "timestep_scale_multiplier": remove_keys_,
-}
-

def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
    ckpt_path: str,
    dtype: torch.dtype,
    version: str = "0.9.0",
):
    PREFIX_KEY = "model.diffusion_model."

    original_state_dict = get_state_dict(load_file(ckpt_path))
    config = {}
    if version == "0.9.5":
        config["_use_causal_rope_fix"] = True
    with init_empty_weights():
-        transformer = LTXVideoTransformer3DModel()
+        transformer = LTXVideoTransformer3DModel(**config)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 512),
            "down_block_types": (
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
            ),
            "decoder_block_out_channels": (128, 256, 512, 512),
            "layers_per_block": (4, 3, 3, 3, 4),
            "decoder_layers_per_block": (4, 3, 3, 3, 4),
            "spatio_temporal_scaling": (True, True, True, False),
            "decoder_spatio_temporal_scaling": (True, True, True, False),
            "decoder_inject_noise": (False, False, False, False, False),
            "downsample_type": ("conv", "conv", "conv", "conv"),
            "upsample_residual": (False, False, False, False),
            "upsample_factor": (1, 1, 1, 1),
            "patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 512),
            "down_block_types": (
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
            ),
            "decoder_block_out_channels": (256, 512, 1024),
            "layers_per_block": (4, 3, 3, 3, 4),
            "decoder_layers_per_block": (5, 6, 7, 8),
            "spatio_temporal_scaling": (True, True, True, False),
            "decoder_spatio_temporal_scaling": (True, True, True),
            "decoder_inject_noise": (True, True, True, False),
            "downsample_type": ("conv", "conv", "conv", "conv"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": True,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
            "decoder_causal": False,
        }
        VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
-        VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
    elif version == "0.9.5":
        config = {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 1024, 2048),
            "down_block_types": (
                "LTXVideo095DownBlock3D",
                "LTXVideo095DownBlock3D",
                "LTXVideo095DownBlock3D",
                "LTXVideo095DownBlock3D",
            ),
            "decoder_block_out_channels": (256, 512, 1024),
            "layers_per_block": (4, 6, 6, 2, 2),
            "decoder_layers_per_block": (5, 5, 5, 5),
            "spatio_temporal_scaling": (True, True, True, True),
            "decoder_spatio_temporal_scaling": (True, True, True),
            "decoder_inject_noise": (False, False, False, False),
            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": True,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,
            "scaling_factor": 1.0,
            "encoder_causal": True,
            "decoder_causal": False,
            "spatial_compression_ratio": 32,
            "temporal_compression_ratio": 8,
        }
        VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
    return config
@@ -223,7 +294,7 @@ def get_args():
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    parser.add_argument(
-        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
+        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
    )
    return parser.parse_args()
@@ -277,14 +348,17 @@ if __name__ == "__main__":
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

-    scheduler = FlowMatchEulerDiscreteScheduler(
-        use_dynamic_shifting=True,
-        base_shift=0.95,
-        max_shift=2.05,
-        base_image_seq_len=1024,
-        max_image_seq_len=4096,
-        shift_terminal=0.1,
-    )
+    if args.version == "0.9.5":
+        scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
+    else:
+        scheduler = FlowMatchEulerDiscreteScheduler(
+            use_dynamic_shifting=True,
+            base_shift=0.95,
+            max_shift=2.05,
+            base_image_seq_len=1024,
+            max_image_seq_len=4096,
+            shift_terminal=0.1,
+        )

    pipe = LTXPipeline(
        scheduler=scheduler,
@@ -402,6 +402,7 @@ else:
            "LDMTextToImagePipeline",
            "LEditsPPPipelineStableDiffusion",
            "LEditsPPPipelineStableDiffusionXL",
            "LTXConditionPipeline",
            "LTXImageToVideoPipeline",
            "LTXPipeline",
            "Lumina2Pipeline",
@@ -947,6 +948,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        LDMTextToImagePipeline,
        LEditsPPPipelineStableDiffusion,
        LEditsPPPipelineStableDiffusionXL,
        LTXConditionPipeline,
        LTXImageToVideoPipeline,
        LTXPipeline,
        Lumina2Pipeline,
@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
        return hidden_states


class LTXVideoDownsampler3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: Union[int, Tuple[int, int, int]] = 1,
        is_causal: bool = True,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels

        out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])

        self.conv = LTXVideoCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            is_causal=is_causal,
            padding_mode=padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)

        residual = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        residual = residual.unflatten(1, (-1, self.group_size))
        residual = residual.mean(dim=2)

        hidden_states = self.conv(hidden_states)
        hidden_states = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        hidden_states = hidden_states + residual

        return hidden_states
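The forward pass above folds each stride-sized spatio-temporal neighbourhood into the channel dimension (a 3D pixel-unshuffle) and builds the residual by group-averaging the folded input channels down to the output width. A self-contained sketch of just that rearrangement, on a dummy tensor and without the causal frame padding:

import torch

# Dummy (B, C, F, H, W) input and a (2, 2, 2) stride, as used by the "spatiotemporal" downsampler.
B, C, F, H, W = 1, 4, 4, 8, 8
stride = (2, 2, 2)
x = torch.randn(B, C, F, H, W)

folded = (
    x.unflatten(4, (-1, stride[2]))   # (B, C, F, H, W/2, 2)
    .unflatten(3, (-1, stride[1]))    # (B, C, F, H/2, 2, W/2, 2)
    .unflatten(2, (-1, stride[0]))    # (B, C, F/2, 2, H/2, 2, W/2, 2)
)
folded = folded.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
print(folded.shape)  # torch.Size([1, 32, 2, 4, 4]) -> channels grew by 2*2*2, F/H/W halved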

class LTXVideoUpsampler3d(nn.Module):
    def __init__(
        self,
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
        is_causal: bool = True,
        residual: bool = False,
        upscale_factor: int = 1,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
            kernel_size=3,
            stride=1,
            is_causal=is_causal,
            padding_mode=padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module):
        return hidden_states


class LTXVideo095DownBlock3D(nn.Module):
    r"""
    Down block used in the LTXVideo model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
            Whether or not to downsample across temporal dimension.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        is_causal: bool = True,
        downsample_type: str = "conv",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        resnets = []
        for _ in range(num_layers):
            resnets.append(
                LTXVideoResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.downsamplers = None
        if spatio_temporal_scale:
            self.downsamplers = nn.ModuleList()

            if downsample_type == "conv":
                self.downsamplers.append(
                    LTXVideoCausalConv3d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=3,
                        stride=(2, 2, 2),
                        is_causal=is_causal,
                    )
                )
            elif downsample_type == "spatial":
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
                    )
                )
            elif downsample_type == "temporal":
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
                    )
                )
            elif downsample_type == "spatiotemporal":
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
                    )
                )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        r"""Forward method of the `LTXDownBlock3D` class."""

        for i, resnet in enumerate(self.resnets):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
            else:
                hidden_states = resnet(hidden_states, temb, generator)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states


# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXVideoMidBlock3d(nn.Module):
    r"""
@@ -593,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
        in_channels: int = 3,
        out_channels: int = 128,
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        down_block_types: Tuple[str, ...] = (
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
        ),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
@@ -617,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
        )

        # down blocks
-        num_block_out_channels = len(block_out_channels)
+        is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
+        num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
        self.down_blocks = nn.ModuleList([])
        for i in range(num_block_out_channels):
            input_channel = output_channel
-            output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
+            if not is_ltx_095:
+                output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
+            else:
+                output_channel = block_out_channels[i + 1]

-            down_block = LTXVideoDownBlock3D(
-                in_channels=input_channel,
-                out_channels=output_channel,
-                num_layers=layers_per_block[i],
-                resnet_eps=resnet_norm_eps,
-                spatio_temporal_scale=spatio_temporal_scaling[i],
-                is_causal=is_causal,
-            )
+            if down_block_types[i] == "LTXVideoDownBlock3D":
+                down_block = LTXVideoDownBlock3D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=resnet_norm_eps,
+                    spatio_temporal_scale=spatio_temporal_scaling[i],
+                    is_causal=is_causal,
+                )
+            elif down_block_types[i] == "LTXVideo095DownBlock3D":
+                down_block = LTXVideo095DownBlock3D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=resnet_norm_eps,
+                    spatio_temporal_scale=spatio_temporal_scaling[i],
+                    is_causal=is_causal,
+                    downsample_type=downsample_type[i],
+                )
+            else:
+                raise ValueError(f"Unknown down block type: {down_block_types[i]}")

            self.down_blocks.append(down_block)
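With the 0.9.5 config shown earlier in this diff, the `is_ltx_095` branch builds one fewer down block than there are entries in `block_out_channels`, and each block's output width comes straight from the next entry. A quick illustrative check of the resulting channel progression (the initialisation of `output_channel` from the first entry is assumed from the surrounding encoder code, which is not fully shown in this hunk):

block_out_channels = (128, 256, 512, 1024, 2048)
is_ltx_095 = True
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)

output_channel = block_out_channels[0]
for i in range(num_block_out_channels):
    input_channel = output_channel
    output_channel = block_out_channels[i + 1]
    print(f"down block {i}: {input_channel} -> {output_channel}")
# down block 0: 128 -> 256
# down block 1: 256 -> 512
# down block 2: 512 -> 1024
# down block 3: 1024 -> 2048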
@@ -794,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
        # timestep embedding
        self.time_embedder = None
        self.scale_shift_table = None
        self.timestep_scale_multiplier = None
        if timestep_conditioning:
            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
            self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
@@ -803,6 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)

        if self.timestep_scale_multiplier is not None:
            temb = temb * self.timestep_scale_multiplier

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
@@ -891,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        out_channels: int = 3,
        latent_channels: int = 128,
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        down_block_types: Tuple[str, ...] = (
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
        ),
        decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
        upsample_residual: Tuple[bool, ...] = (False, False, False, False),
        upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
        timestep_conditioning: bool = False,
@@ -906,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        scaling_factor: float = 1.0,
        encoder_causal: bool = True,
        decoder_causal: bool = False,
        spatial_compression_ratio: int = None,
        temporal_compression_ratio: int = None,
    ) -> None:
        super().__init__()
@@ -913,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
            in_channels=in_channels,
            out_channels=latent_channels,
            block_out_channels=block_out_channels,
            down_block_types=down_block_types,
            spatio_temporal_scaling=spatio_temporal_scaling,
            layers_per_block=layers_per_block,
            downsample_type=downsample_type,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            resnet_norm_eps=resnet_norm_eps,
@@ -941,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        self.register_buffer("latents_mean", latents_mean, persistent=True)
        self.register_buffer("latents_std", latents_std, persistent=True)

-        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
-        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
+        self.spatial_compression_ratio = (
+            patch_size * 2 ** sum(spatio_temporal_scaling)
+            if spatial_compression_ratio is None
+            else spatial_compression_ratio
+        )
+        self.temporal_compression_ratio = (
+            patch_size_t * 2 ** sum(spatio_temporal_scaling)
+            if temporal_compression_ratio is None
+            else temporal_compression_ratio
+        )

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
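A short sanity check on why the new explicit override matters for 0.9.5 (numbers taken from the 0.9.5 config earlier in this diff; the printout itself is only illustrative): the fallback formula only counts how many blocks downsample at all, while the 0.9.5 encoder mixes spatial-only, temporal-only and spatio-temporal downsamplers, so the derived values would not match the configured 1:8 temporal / 1:32 spatial compression.

patch_size, patch_size_t = 4, 1
spatio_temporal_scaling = (True, True, True, True)

print(patch_size * 2 ** sum(spatio_temporal_scaling))    # 64, fallback guess for spatial
print(patch_size_t * 2 ** sum(spatio_temporal_scaling))  # 16, fallback guess for temporal
# The 0.9.5 config therefore passes spatial_compression_ratio=32 and
# temporal_compression_ratio=8 explicitly instead of relying on the fallback.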
@@ -14,7 +14,7 @@
# limitations under the License.

import math
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -113,20 +113,19 @@ class LTXVideoRotaryPosEmbed(nn.Module):
        self.patch_size_t = patch_size_t
        self.theta = theta

-    def forward(
+    def _prepare_video_coords(
        self,
-        hidden_states: torch.Tensor,
+        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
-        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size = hidden_states.size(0)
+        rope_interpolation_scale: Tuple[torch.Tensor, float, float],
+        device: torch.device,
+    ) -> torch.Tensor:
        # Always compute rope in fp32
-        grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
-        grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
-        grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
+        grid_h = torch.arange(height, dtype=torch.float32, device=device)
+        grid_w = torch.arange(width, dtype=torch.float32, device=device)
+        grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
        grid = torch.stack(grid, dim=0)
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
@@ -138,6 +137,38 @@ class LTXVideoRotaryPosEmbed(nn.Module):
        grid = grid.flatten(2, 4).transpose(1, 2)

        return grid

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
        video_coords: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = hidden_states.size(0)

        if video_coords is None:
            grid = self._prepare_video_coords(
                batch_size,
                num_frames,
                height,
                width,
                rope_interpolation_scale=rope_interpolation_scale,
                device=hidden_states.device,
            )
        else:
            grid = torch.stack(
                [
                    video_coords[:, 0] / self.base_num_frames,
                    video_coords[:, 1] / self.base_height,
                    video_coords[:, 2] / self.base_width,
                ],
                dim=-1,
            )

        start = 1.0
        end = self.theta
        freqs = self.theta ** torch.linspace(
@@ -367,10 +398,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
+        video_coords: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
@@ -389,7 +421,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

-        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
+        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
@@ -264,7 +264,7 @@ else:
        ]
    )
    _import_structure["latte"] = ["LattePipeline"]
-    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
+    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
    _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
    _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
    _import_structure["marigold"].extend(
@@ -618,7 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            LEditsPPPipelineStableDiffusion,
            LEditsPPPipelineStableDiffusionXL,
        )
-        from .ltx import LTXImageToVideoPipeline, LTXPipeline
+        from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
        from .lumina import LuminaPipeline, LuminaText2ImgPipeline
        from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
        from .marigold import (
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_ltx"] = ["LTXPipeline"]
    _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
    _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_ltx import LTXPipeline
        from .pipeline_ltx_condition import LTXConditionPipeline
        from .pipeline_ltx_image2video import LTXImageToVideoPipeline

else:
@@ -694,9 +694,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
        self._num_timesteps = len(timesteps)

        # 6. Prepare micro-conditions
-        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
        rope_interpolation_scale = (
-            1 / latent_frame_rate,
+            self.vae_temporal_compression_ratio / frame_rate,
            self.vae_spatial_compression_ratio,
            self.vae_spatial_compression_ratio,
        )
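The first entry of `rope_interpolation_scale` is unchanged numerically; the rewrite just drops the intermediate `latent_frame_rate` variable, since 1 / (frame_rate / r) equals r / frame_rate. A trivial check, with the frame rate chosen arbitrarily for illustration:

frame_rate, r = 25, 8  # r stands in for vae_temporal_compression_ratio
assert abs(1 / (frame_rate / r) - r / frame_rate) < 1e-12  # both 0.32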
src/diffusers/pipelines/ltx/pipeline_ltx_condition.py (new file, 1174 lines): file diff suppressed because it is too large.
@@ -764,9 +764,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
        self._num_timesteps = len(timesteps)

        # 6. Prepare micro-conditions
-        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
        rope_interpolation_scale = (
-            1 / latent_frame_rate,
+            self.vae_temporal_compression_ratio / frame_rate,
            self.vae_spatial_compression_ratio,
            self.vae_spatial_compression_ratio,
        )
@@ -377,6 +377,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
@@ -397,6 +398,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
@@ -427,16 +430,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

-        sigma = self.sigmas[self.step_index]
-        sigma_next = self.sigmas[self.step_index + 1]
-
-        prev_sample = sample + (sigma_next - sigma) * model_output
-
-        # Cast sample back to model compatible dtype
-        prev_sample = prev_sample.to(model_output.dtype)
+        if per_token_timesteps is not None:
+            per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
+
+            sigmas = self.sigmas[:, None, None]
+            lower_mask = sigmas < per_token_sigmas[None] - 1e-6
+            lower_sigmas = lower_mask * sigmas
+            lower_sigmas, _ = lower_sigmas.max(dim=0)
+            dt = (per_token_sigmas - lower_sigmas)[..., None]
+        else:
+            sigma = self.sigmas[self.step_index]
+            sigma_next = self.sigmas[self.step_index + 1]
+            dt = sigma_next - sigma
+
+        prev_sample = sample + dt * model_output

        # upon completion increase step index by one
        self._step_index += 1
+        if per_token_timesteps is None:
+            # Cast sample back to model compatible dtype
+            prev_sample = prev_sample.to(model_output.dtype)

        if not return_dict:
            return (prev_sample,)
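To make the per-token branch above concrete, here is a tiny standalone walk-through of how `dt` is derived: for each token, the largest grid sigma strictly below that token's own sigma is found, and the step spans the gap between the two. The numbers below are made up for illustration.

import torch

sigmas = torch.tensor([1.0, 0.5, 0.0])          # scheduler grid, high -> low noise
per_token_sigmas = torch.tensor([[0.9, 0.4]])   # (batch=1, tokens=2), e.g. per_token_timesteps / 1000

grid = sigmas[:, None, None]                    # (num_sigmas, 1, 1)
lower_mask = grid < per_token_sigmas[None] - 1e-6
lower_sigmas = (lower_mask * grid).max(dim=0).values
dt = (per_token_sigmas - lower_sigmas)[..., None]

print(lower_sigmas)    # tensor([[0.5000, 0.0000]])
print(dt.squeeze(-1))  # tensor([[0.4000, 0.4000]])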
@@ -1217,6 +1217,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class LTXConditionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class LTXImageToVideoPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
tests/pipelines/ltx/test_ltx_condition.py (new file, 284 lines)
@@ -0,0 +1,284 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKLLTXVideo,
    FlowMatchEulerDiscreteScheduler,
    LTXConditionPipeline,
    LTXVideoTransformer3DModel,
)
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = LTXConditionPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = LTXVideoTransformer3DModel(
            in_channels=8,
            out_channels=8,
            patch_size=1,
            patch_size_t=1,
            num_attention_heads=4,
            attention_head_dim=8,
            cross_attention_dim=32,
            num_layers=1,
            caption_channels=32,
        )

        torch.manual_seed(0)
        vae = AutoencoderKLLTXVideo(
            in_channels=3,
            out_channels=3,
            latent_channels=8,
            block_out_channels=(8, 8, 8, 8),
            decoder_block_out_channels=(8, 8, 8, 8),
            layers_per_block=(1, 1, 1, 1, 1),
            decoder_layers_per_block=(1, 1, 1, 1, 1),
            spatio_temporal_scaling=(True, True, False, False),
            decoder_spatio_temporal_scaling=(True, True, False, False),
            decoder_inject_noise=(False, False, False, False, False),
            upsample_residual=(False, False, False, False),
            upsample_factor=(1, 1, 1, 1),
            timestep_conditioning=False,
            patch_size=1,
            patch_size_t=1,
            encoder_causal=True,
            decoder_causal=False,
        )
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0, use_conditions=False):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
        if use_conditions:
            conditions = LTXVideoCondition(
                image=image,
            )
        else:
            conditions = None

        inputs = {
            "conditions": conditions,
            "image": None if use_conditions else image,
            "prompt": "dance monkey",
            "negative_prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "height": 32,
            "width": 32,
            # 8 * k + 1 is the recommendation
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs2 = self.get_dummy_inputs(device, use_conditions=True)
        video = pipe(**inputs).frames
        generated_video = video[0]
        video2 = pipe(**inputs2).frames
        generated_video2 = video2[0]

        self.assertEqual(generated_video.shape, (9, 3, 32, 32))

        max_diff = np.abs(generated_video - generated_video2).max()
        self.assertLessEqual(max_diff, 1e-3)

    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in a everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )