1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into pinned-context

This commit is contained in:
Dhruv Nair
2025-03-18 11:52:43 +01:00
19 changed files with 1987 additions and 137 deletions

View File

@@ -28,7 +28,51 @@ env:
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
setup_torch_cuda_pipeline_matrix:
needs: [check_code_quality, check_repository_consistency]
name: Setup Torch Pipelines CUDA Slow Tests Matrix
runs-on:
group: aws-general-8-plus

View File

@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
- all
- __call__
## LTXConditionPipeline
[[autodoc]] LTXConditionPipeline
- all
- __call__
## LTXPipelineOutput
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput

View File

@@ -1,8 +1,6 @@
# Generating images using Flux and PyTorch/XLA
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
## Create TPU
@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
python3 -c "import torch; import torch_xla;"
```
Install dependencies
Clone the diffusers repo and install dependencies
```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install transformers accelerate sentencepiece structlog
pushd ../../..
pip install .
popd
cd examples/research_projects/pytorch_xla/inference/flux/
```
## Run the inference job
### Authenticate
Run the following command to authenticate your token in order to download Flux weights.
**Gated Model**
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
```bash
huggingface-cli login

View File

@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
"last_scale_shift_table": "scale_shift_table",
}
VAE_095_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
"model.diffusion_model": remove_keys_,
}
VAE_091_SPECIAL_KEYS_REMAP = {
"timestep_scale_multiplier": remove_keys_,
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
version: str = "0.9.0",
):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
config = {}
if version == "0.9.5":
config["_use_causal_rope_fix"] = True
with init_empty_weights():
transformer = LTXVideoTransformer3DModel()
transformer = LTXVideoTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (128, 256, 512, 512),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (4, 3, 3, 3, 4),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True, False),
"decoder_inject_noise": (False, False, False, False, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (5, 6, 7, 8),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
elif version == "0.9.5":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 1024, 2048),
"down_block_types": (
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
}
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
return config
@@ -223,7 +294,7 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
)
return parser.parse_args()
@@ -277,14 +348,17 @@ if __name__ == "__main__":
for param in text_encoder.parameters():
param.data = param.data.contiguous()
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
if args.version == "0.9.5":
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTXPipeline(
scheduler=scheduler,

View File

@@ -402,6 +402,7 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXPipeline",
"Lumina2Pipeline",
@@ -947,6 +948,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXPipeline,
Lumina2Pipeline,

View File

@@ -73,14 +73,10 @@ class ModuleGroup:
self.cpu_param_dict = {}
for module in self.modules:
for param in module.parameters():
if self.low_cpu_mem_usage:
self.cpu_param_dict[param] = param.data.cpu()
else:
self.cpu_param_dict[param] = param.data.cpu().pin_memory()
self.cpu_param_dict.update(_get_cpu_param_dict(module, self.low_cpu_mem_usage))
@contextmanager
def _pin_memory(self):
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
for param, tensor in self.cpu_param_dict.items():
@@ -103,52 +99,37 @@ class ModuleGroup:
with context:
if self.stream is not None:
with self._pin_memory() as pinned_dict:
with self._pinned_memory_tensors() as pinned_memory:
for module in self.modules:
for param in module.parameters():
if param in pinned_dict:
param.data = pinned_dict[param].to(
self.onload_device, non_blocking=True
)
if self.parameters:
for param in self.parameters:
param.data = param.data.to(
self.onload_device, non_blocking=True
)
if self.buffers:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=True)
param.data = pinned_memory[param].to(
self.onload_device, non_blocking=self.non_blocking
)
else:
# Standard transfer for non-stream case
for module in self.modules:
module.to(self.onload_device, non_blocking=self.non_blocking)
if self.parameters:
for param in self.parameters:
for group_module in self.modules:
for param in group_module.parameters():
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.buffers:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
for module in self.modules:
for param in module.parameters():
if param in self.cpu_param_dict:
param.data = self.cpu_param_dict[param]
if self.parameters:
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if self.parameters is not None:
for param in self.parameters:
if param in self.cpu_param_dict:
param.data = self.cpu_param_dict[param]
if self.buffers:
param.data = self.cpu_param_dict[param]
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
buffer.data = self.cpu_param_dict[buffer]
else:
for module in self.modules:
module.to(self.offload_device, non_blocking=self.non_blocking)
@@ -223,6 +204,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
self._layer_execution_tracker_module_names = set()
def initialize_hook(self, module):
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
# layers are executed during the forward pass.
@@ -234,14 +222,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
if group_offloading_hook is not None:
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
# For the first forward pass, we have to load in a blocking manner
group_offloading_hook.group.non_blocking = False
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
self._layer_execution_tracker_module_names.add(name)
@@ -271,6 +253,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the layer execution tracker hooks from the submodules
base_module_registry = module._diffusers_hook
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
for i in range(num_executed):
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
@@ -278,8 +261,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
# Apply lazy prefetching by setting required attributes
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
# see the benefits of prefetching.
for hook in group_offloading_hooks:
hook.group.non_blocking = True
# Set required attributes for prefetching
if num_executed > 0:
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
@@ -437,6 +425,11 @@ def _apply_group_offloading_block_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
cpu_param_dict = _get_pinned_cpu_param_dict(module)
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
unmatched_modules = []
@@ -532,11 +525,7 @@ def _apply_group_offloading_leaf_level(
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
"""
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
"""
cpu_param_dict = _get_pinned_cpu_param_dict(module)
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
@@ -652,6 +641,23 @@ def _apply_lazy_group_offloading_hook(
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
def _get_cpu_param_dict(module: torch.nn.Module, low_cpu_mem_usage: bool = False) -> Dict[torch.nn.Parameter, torch.Tensor]:
cpu_param_dict = {}
for param in module.parameters():
if low_cpu_mem_usage:
cpu_param_dict[param] = param.data.cpu()
else:
cpu_param_dict[param] = param.data.cpu().pin_memory()
for buffer in module.buffers():
if low_cpu_mem_usage:
cpu_param_dict[buffer] = buffer.data.cpu()
else:
cpu_param_dict[buffer] = buffer.data.cpu().pin_memory()
return cpu_param_dict
def _gather_parameters_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[torch.nn.Parameter]:

View File

@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
return hidden_states
class LTXVideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
padding_mode: str = "zeros",
) -> None:
super().__init__()
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
residual = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
residual = residual.unflatten(1, (-1, self.group_size))
residual = residual.mean(dim=2)
hidden_states = self.conv(hidden_states)
hidden_states = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
hidden_states = hidden_states + residual
return hidden_states
class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
padding_mode: str = "zeros",
) -> None:
super().__init__()
@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module):
return hidden_states
class LTXVideo095DownBlock3D(nn.Module):
r"""
Down block used in the LTXVideo model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
spatio_temporal_scale (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
Whether or not to downsample across temporal dimension.
is_causal (`bool`, defaults to `True`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
downsample_type: str = "conv",
):
super().__init__()
out_channels = out_channels or in_channels
resnets = []
for _ in range(num_layers):
resnets.append(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
)
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList()
if downsample_type == "conv":
self.downsamplers.append(
LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=(2, 2, 2),
is_causal=is_causal,
)
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
)
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
)
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
)
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
else:
hidden_states = resnet(hidden_states, temb, generator)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXVideoMidBlock3d(nn.Module):
r"""
@@ -593,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
in_channels: int = 3,
out_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
@@ -617,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
)
# down blocks
num_block_out_channels = len(block_out_channels)
is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
self.down_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
if not is_ltx_095:
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
else:
output_channel = block_out_channels[i + 1]
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
if down_block_types[i] == "LTXVideoDownBlock3D":
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
elif down_block_types[i] == "LTXVideo095DownBlock3D":
down_block = LTXVideo095DownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
downsample_type=downsample_type[i],
)
else:
raise ValueError(f"Unknown down block type: {down_block_types[i]}")
self.down_blocks.append(down_block)
@@ -794,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
self.timestep_scale_multiplier = None
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
@@ -803,6 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if self.timestep_scale_multiplier is not None:
temb = temb * self.timestep_scale_multiplier
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
@@ -891,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
timestep_conditioning: bool = False,
@@ -906,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
spatial_compression_ratio: int = None,
temporal_compression_ratio: int = None,
) -> None:
super().__init__()
@@ -913,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
down_block_types=down_block_types,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
downsample_type=downsample_type,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
@@ -941,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
self.spatial_compression_ratio = (
patch_size * 2 ** sum(spatio_temporal_scaling)
if spatial_compression_ratio is None
else spatial_compression_ratio
)
self.temporal_compression_ratio = (
patch_size_t * 2 ** sum(spatio_temporal_scaling)
if temporal_compression_ratio is None
else temporal_compression_ratio
)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.

View File

@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
input_tensor = self.conv_shortcut(input_tensor.contiguous())
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -113,20 +113,19 @@ class LTXVideoRotaryPosEmbed(nn.Module):
self.patch_size_t = patch_size_t
self.theta = theta
def forward(
def _prepare_video_coords(
self,
hidden_states: torch.Tensor,
batch_size: int,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)
rope_interpolation_scale: Tuple[torch.Tensor, float, float],
device: torch.device,
) -> torch.Tensor:
# Always compute rope in fp32
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
grid_h = torch.arange(height, dtype=torch.float32, device=device)
grid_w = torch.arange(width, dtype=torch.float32, device=device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
@@ -138,6 +137,38 @@ class LTXVideoRotaryPosEmbed(nn.Module):
grid = grid.flatten(2, 4).transpose(1, 2)
return grid
def forward(
self,
hidden_states: torch.Tensor,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
video_coords: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)
if video_coords is None:
grid = self._prepare_video_coords(
batch_size,
num_frames,
height,
width,
rope_interpolation_scale=rope_interpolation_scale,
device=hidden_states.device,
)
else:
grid = torch.stack(
[
video_coords[:, 0] / self.base_num_frames,
video_coords[:, 1] / self.base_height,
video_coords[:, 2] / self.base_width,
],
dim=-1,
)
start = 1.0
end = self.theta
freqs = self.theta ** torch.linspace(
@@ -367,10 +398,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
video_coords: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
@@ -389,7 +421,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:

View File

@@ -264,7 +264,7 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
@@ -618,7 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXImageToVideoPipeline, LTXPipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (

View File

@@ -63,6 +63,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import FluxControlNetPipeline
>>> from diffusers import FluxControlNetModel
>>> base_model = "black-forest-labs/FLUX.1-dev"
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
>>> pipe = FluxControlNetPipeline.from_pretrained(

View File

@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
else:

View File

@@ -694,9 +694,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)

File diff suppressed because it is too large Load Diff

View File

@@ -764,9 +764,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)

View File

@@ -108,31 +108,16 @@ def prompt_clean(text):
return text
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor,
latents_mean: torch.Tensor,
latents_std: torch.Tensor,
generator: Optional[torch.Generator] = None,
sample_mode: str = "sample",
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
encoder_output.latent_dist.logvar = torch.clamp(
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
)
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
encoder_output.latent_dist.logvar = torch.clamp(
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
)
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return (encoder_output.latents - latents_mean) * latents_std
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
@@ -412,13 +397,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if isinstance(generator, list):
latent_condition = [
retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
]
latent_condition = torch.cat(latent_condition)
else:
latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
latent_condition = (latent_condition - latents_mean) * latents_std
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]

View File

@@ -377,6 +377,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
per_token_timesteps: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
@@ -397,6 +398,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
per_token_timesteps (`torch.Tensor`, *optional*):
The timesteps for each token in the sample.
return_dict (`bool`):
Whether or not to return a
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
@@ -427,16 +430,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
if per_token_timesteps is not None:
per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
prev_sample = sample + (sigma_next - sigma) * model_output
sigmas = self.sigmas[:, None, None]
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
lower_sigmas = lower_mask * sigmas
lower_sigmas, _ = lower_sigmas.max(dim=0)
dt = (per_token_sigmas - lower_sigmas)[..., None]
else:
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
dt = sigma_next - sigma
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
prev_sample = sample + dt * model_output
# upon completion increase step index by one
self._step_index += 1
if per_token_timesteps is None:
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
if not return_dict:
return (prev_sample,)

View File

@@ -1217,6 +1217,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LTXConditionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTXImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,284 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLLTXVideo,
FlowMatchEulerDiscreteScheduler,
LTXConditionPipeline,
LTXVideoTransformer3DModel,
)
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LTXConditionPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
out_channels=8,
patch_size=1,
patch_size_t=1,
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
num_layers=1,
caption_channels=32,
)
torch.manual_seed(0)
vae = AutoencoderKLLTXVideo(
in_channels=3,
out_channels=3,
latent_channels=8,
block_out_channels=(8, 8, 8, 8),
decoder_block_out_channels=(8, 8, 8, 8),
layers_per_block=(1, 1, 1, 1, 1),
decoder_layers_per_block=(1, 1, 1, 1, 1),
spatio_temporal_scaling=(True, True, False, False),
decoder_spatio_temporal_scaling=(True, True, False, False),
decoder_inject_noise=(False, False, False, False, False),
upsample_residual=(False, False, False, False),
upsample_factor=(1, 1, 1, 1),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
decoder_causal=False,
)
vae.use_framewise_encoding = False
vae.use_framewise_decoding = False
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0, use_conditions=False):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
if use_conditions:
conditions = LTXVideoCondition(
image=image,
)
else:
conditions = None
inputs = {
"conditions": conditions,
"image": None if use_conditions else image,
"prompt": "dance monkey",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": 32,
"width": 32,
# 8 * k + 1 is the recommendation
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs2 = self.get_dummy_inputs(device, use_conditions=True)
video = pipe(**inputs).frames
generated_video = video[0]
video2 = pipe(**inputs2).frames
generated_video2 = video2[0]
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
max_diff = np.abs(generated_video - generated_video2).max()
self.assertLessEqual(max_diff, 1e-3)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in a everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)