Merge branch 'main' into pinned-context
(mirror of https://github.com/huggingface/diffusers.git)
44  .github/workflows/pr_tests_gpu.yml  (vendored)
@@ -28,7 +28,51 @@ env:
  PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run

jobs:
  check_code_quality:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.8"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .[quality]
      - name: Check quality
        run: make quality
      - name: Check if failure
        if: ${{ failure() }}
        run: |
          echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY

  check_repository_consistency:
    needs: check_code_quality
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.8"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .[quality]
      - name: Check repo consistency
        run: |
          python utils/check_copies.py
          python utils/check_dummies.py
          python utils/check_support_list.py
          make deps_table_check_updated
      - name: Check if failure
        if: ${{ failure() }}
        run: |
          echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY

  setup_torch_cuda_pipeline_matrix:
    needs: [check_code_quality, check_repository_consistency]
    name: Setup Torch Pipelines CUDA Slow Tests Matrix
    runs-on:
      group: aws-general-8-plus
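The failure messages above double as local repro commands; a minimal sketch of running the same gates on a developer machine, assuming a checkout of the diffusers repo:

```bash
# Reproduce the CI checks locally (commands taken from the step summaries above)
pip install -e .[quality]
make style && make quality   # code quality gate
make fix-copies              # fixes what check_copies.py / check_dummies.py verify
```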
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
    - all
    - __call__

## LTXConditionPipeline

[[autodoc]] LTXConditionPipeline
    - all
    - __call__

## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
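Since the docs hunk above only registers the autodoc entry, a minimal usage sketch of the new pipeline may help; the checkpoint id and the `frame_index` argument are assumptions based on the LTX 0.9.5 support added elsewhere in this commit:

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_image

# Checkpoint id is an assumption; any LTX-Video 0.9.5-style checkpoint should work.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition the generation on a single image placed at the first frame.
image = load_image("input.png")
condition = LTXVideoCondition(image=image, frame_index=0)

video = pipe(
    conditions=[condition],
    prompt="A ship sailing through a stormy sea",
    num_frames=161,  # 8 * k + 1 is the recommendation used in the tests below
    num_inference_steps=40,
).frames[0]
export_to_video(video, "ship.mp4", fps=24)
```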
@@ -1,8 +1,6 @@
# Generating images using Flux and PyTorch/XLA

The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.

It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation, using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.

## Create TPU

@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
python3 -c "import torch; import torch_xla;"
```

Install dependencies
Clone the diffusers repo and install dependencies

```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install transformers accelerate sentencepiece structlog
pushd ../../..
pip install .
popd
cd examples/research_projects/pytorch_xla/inference/flux/
```

## Run the inference job

### Authenticate

Run the following command to authenticate your token in order to download Flux weights.
**Gated Model**

As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in:

```bash
huggingface-cli login
@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
    "last_scale_shift_table": "scale_shift_table",
}

VAE_095_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.upsamplers.0",
    "up_blocks.8": "up_blocks.3",
    # encoder
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.1",
    "down_blocks.3": "down_blocks.1.downsamplers.0",
    "down_blocks.4": "down_blocks.2",
    "down_blocks.5": "down_blocks.2.downsamplers.0",
    "down_blocks.6": "down_blocks.3",
    "down_blocks.7": "down_blocks.3.downsamplers.0",
    "down_blocks.8": "mid_block",
    # common
    "last_time_embedder": "time_embedder",
    "last_scale_shift_table": "scale_shift_table",
}

VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_,
    "per_channel_statistics.mean-of-means": remove_keys_,
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
    "model.diffusion_model": remove_keys_,
}

VAE_091_SPECIAL_KEYS_REMAP = {
    "timestep_scale_multiplier": remove_keys_,
}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
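For context, a minimal sketch of how rename dicts like `VAE_095_RENAME_DICT` are applied during conversion; the loop is an assumption mirroring the `update_state_dict_inplace` helper named in the next hunk and the transformer loop later in this script:

```python
from typing import Any, Dict


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Move a tensor from its original key to the diffusers-style key.
    state_dict[new_key] = state_dict.pop(old_key)


# Assuming `state_dict` was produced by get_state_dict(...) above:
for key in list(state_dict.keys()):
    new_key = key[:]
    # Substring replacement, so "up_blocks.3.res" becomes "up_blocks.1.upsamplers.0.res".
    for replace_key, rename_key in VAE_095_RENAME_DICT.items():
        new_key = new_key.replace(replace_key, rename_key)
    update_state_dict_inplace(state_dict, key, new_key)
```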
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
    ckpt_path: str,
    dtype: torch.dtype,
    version: str = "0.9.0",
):
    PREFIX_KEY = "model.diffusion_model."

    original_state_dict = get_state_dict(load_file(ckpt_path))
    config = {}
    if version == "0.9.5":
        config["_use_causal_rope_fix"] = True
    with init_empty_weights():
        transformer = LTXVideoTransformer3DModel()
        transformer = LTXVideoTransformer3DModel(**config)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
        "out_channels": 3,
        "latent_channels": 128,
        "block_out_channels": (128, 256, 512, 512),
        "down_block_types": (
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
        ),
        "decoder_block_out_channels": (128, 256, 512, 512),
        "layers_per_block": (4, 3, 3, 3, 4),
        "decoder_layers_per_block": (4, 3, 3, 3, 4),
        "spatio_temporal_scaling": (True, True, True, False),
        "decoder_spatio_temporal_scaling": (True, True, True, False),
        "decoder_inject_noise": (False, False, False, False, False),
        "downsample_type": ("conv", "conv", "conv", "conv"),
        "upsample_residual": (False, False, False, False),
        "upsample_factor": (1, 1, 1, 1),
        "patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
        "out_channels": 3,
        "latent_channels": 128,
        "block_out_channels": (128, 256, 512, 512),
        "down_block_types": (
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
        ),
        "decoder_block_out_channels": (256, 512, 1024),
        "layers_per_block": (4, 3, 3, 3, 4),
        "decoder_layers_per_block": (5, 6, 7, 8),
        "spatio_temporal_scaling": (True, True, True, False),
        "decoder_spatio_temporal_scaling": (True, True, True),
        "decoder_inject_noise": (True, True, True, False),
        "downsample_type": ("conv", "conv", "conv", "conv"),
        "upsample_residual": (True, True, True),
        "upsample_factor": (2, 2, 2),
        "timestep_conditioning": True,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
            "decoder_causal": False,
        }
        VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
        VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
    elif version == "0.9.5":
        config = {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 1024, 2048),
            "down_block_types": (
                "LTXVideo095DownBlock3D",
                "LTXVideo095DownBlock3D",
                "LTXVideo095DownBlock3D",
                "LTXVideo095DownBlock3D",
            ),
            "decoder_block_out_channels": (256, 512, 1024),
            "layers_per_block": (4, 6, 6, 2, 2),
            "decoder_layers_per_block": (5, 5, 5, 5),
            "spatio_temporal_scaling": (True, True, True, True),
            "decoder_spatio_temporal_scaling": (True, True, True),
            "decoder_inject_noise": (False, False, False, False),
            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": True,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,
            "scaling_factor": 1.0,
            "encoder_causal": True,
            "decoder_causal": False,
            "spatial_compression_ratio": 32,
            "temporal_compression_ratio": 8,
        }
        VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
    return config
@@ -223,7 +294,7 @@ def get_args():
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    parser.add_argument(
        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
    )
    return parser.parse_args()
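A hypothetical invocation of the updated script for a 0.9.5 checkpoint; the script filename is an assumption, and only `--version`, `--dtype`, and `--output_path` actually appear in this hunk:

```bash
python convert_ltx_to_diffusers.py \
  --version 0.9.5 \
  --dtype fp32 \
  --output_path ./ltx-video-0.9.5-diffusers
```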
@@ -277,14 +348,17 @@ if __name__ == "__main__":
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

    scheduler = FlowMatchEulerDiscreteScheduler(
        use_dynamic_shifting=True,
        base_shift=0.95,
        max_shift=2.05,
        base_image_seq_len=1024,
        max_image_seq_len=4096,
        shift_terminal=0.1,
    )
    if args.version == "0.9.5":
        scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
    else:
        scheduler = FlowMatchEulerDiscreteScheduler(
            use_dynamic_shifting=True,
            base_shift=0.95,
            max_shift=2.05,
            base_image_seq_len=1024,
            max_image_seq_len=4096,
            shift_terminal=0.1,
        )

    pipe = LTXPipeline(
        scheduler=scheduler,
@@ -402,6 +402,7 @@ else:
        "LDMTextToImagePipeline",
        "LEditsPPPipelineStableDiffusion",
        "LEditsPPPipelineStableDiffusionXL",
        "LTXConditionPipeline",
        "LTXImageToVideoPipeline",
        "LTXPipeline",
        "Lumina2Pipeline",
@@ -947,6 +948,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        LDMTextToImagePipeline,
        LEditsPPPipelineStableDiffusion,
        LEditsPPPipelineStableDiffusionXL,
        LTXConditionPipeline,
        LTXImageToVideoPipeline,
        LTXPipeline,
        Lumina2Pipeline,
@@ -73,14 +73,10 @@ class ModuleGroup:

        self.cpu_param_dict = {}
        for module in self.modules:
            for param in module.parameters():
                if self.low_cpu_mem_usage:
                    self.cpu_param_dict[param] = param.data.cpu()
                else:
                    self.cpu_param_dict[param] = param.data.cpu().pin_memory()
            self.cpu_param_dict.update(_get_cpu_param_dict(module, self.low_cpu_mem_usage))

    @contextmanager
    def _pin_memory(self):
    def _pinned_memory_tensors(self):
        pinned_dict = {}
        try:
            for param, tensor in self.cpu_param_dict.items():
@@ -103,52 +99,37 @@ class ModuleGroup:

        with context:
            if self.stream is not None:
                with self._pin_memory() as pinned_dict:
                with self._pinned_memory_tensors() as pinned_memory:
                    for module in self.modules:
                        for param in module.parameters():
                            if param in pinned_dict:
                                param.data = pinned_dict[param].to(
                                    self.onload_device, non_blocking=True
                                )

                    if self.parameters:
                        for param in self.parameters:
                            param.data = param.data.to(
                                self.onload_device, non_blocking=True
                            )

                    if self.buffers:
                        for buffer in self.buffers:
                            buffer.data = buffer.data.to(self.onload_device, non_blocking=True)
                            param.data = pinned_memory[param].to(
                                self.onload_device, non_blocking=self.non_blocking
                            )
            else:
                # Standard transfer for non-stream case
                for module in self.modules:
                    module.to(self.onload_device, non_blocking=self.non_blocking)
                if self.parameters:
                    for param in self.parameters:
                for group_module in self.modules:
                    for param in group_module.parameters():
                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
                if self.buffers:
                    for buffer in self.buffers:
                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)

            if self.parameters is not None:
                for param in self.parameters:
                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
            if self.buffers is not None:
                for buffer in self.buffers:
                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)

    def offload_(self):
        r"""Offloads the group of modules to the offload_device."""
        if self.stream is not None:
            torch.cuda.current_stream().synchronize()

            for module in self.modules:
                for param in module.parameters():
                    if param in self.cpu_param_dict:
                        param.data = self.cpu_param_dict[param]

            if self.parameters:
            for group_module in self.modules:
                for param in group_module.parameters():
                    param.data = self.cpu_param_dict[param]
            if self.parameters is not None:
                for param in self.parameters:
                    if param in self.cpu_param_dict:
                        param.data = self.cpu_param_dict[param]

            if self.buffers:
                    param.data = self.cpu_param_dict[param]
            if self.buffers is not None:
                for buffer in self.buffers:
                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
                    buffer.data = self.cpu_param_dict[buffer]
        else:
            for module in self.modules:
                module.to(self.offload_device, non_blocking=self.non_blocking)
@@ -223,6 +204,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
        self._layer_execution_tracker_module_names = set()

    def initialize_hook(self, module):
        def make_execution_order_update_callback(current_name, current_submodule):
            def callback():
                logger.debug(f"Adding {current_name} to the execution order")
                self.execution_order.append((current_name, current_submodule))

            return callback

        # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
        # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
        # layers are executed during the forward pass.
@@ -234,14 +222,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
            group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)

            if group_offloading_hook is not None:

                def make_execution_order_update_callback(current_name, current_submodule):
                    def callback():
                        logger.debug(f"Adding {current_name} to the execution order")
                        self.execution_order.append((current_name, current_submodule))

                    return callback

                # For the first forward pass, we have to load in a blocking manner
                group_offloading_hook.group.non_blocking = False
                layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
                registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
                self._layer_execution_tracker_module_names.add(name)
@@ -271,6 +253,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
        # Remove the layer execution tracker hooks from the submodules
        base_module_registry = module._diffusers_hook
        registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]

        for i in range(num_executed):
            registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
@@ -278,8 +261,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
        # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
        base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)

        # Apply lazy prefetching by setting required attributes
        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
        # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
        # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
        # see the benefits of prefetching.
        for hook in group_offloading_hooks:
            hook.group.non_blocking = True

        # Set required attributes for prefetching
        if num_executed > 0:
            base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
            base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
@@ -437,6 +425,11 @@ def _apply_group_offloading_block_level(
        for overlapping computation and data transfer.
    """

    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
    cpu_param_dict = None
    if stream is not None:
        cpu_param_dict = _get_pinned_cpu_param_dict(module)

    # Create module groups for ModuleList and Sequential blocks
    modules_with_group_offloading = set()
    unmatched_modules = []
@@ -532,11 +525,7 @@ def _apply_group_offloading_leaf_level(
    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
    cpu_param_dict = None
    if stream is not None:
        """
        for param in module.parameters():
            param.data = param.data.cpu().pin_memory()
        cpu_param_dict = {param: param.data for param in module.parameters()}
        """
        cpu_param_dict = _get_pinned_cpu_param_dict(module)

    # Create module groups for leaf modules and apply group offloading hooks
    modules_with_group_offloading = set()
@@ -652,6 +641,23 @@ def _apply_lazy_group_offloading_hook(
    registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


def _get_cpu_param_dict(module: torch.nn.Module, low_cpu_mem_usage: bool = False) -> Dict[torch.nn.Parameter, torch.Tensor]:
    cpu_param_dict = {}
    for param in module.parameters():
        if low_cpu_mem_usage:
            cpu_param_dict[param] = param.data.cpu()
        else:
            cpu_param_dict[param] = param.data.cpu().pin_memory()

    for buffer in module.buffers():
        if low_cpu_mem_usage:
            cpu_param_dict[buffer] = buffer.data.cpu()
        else:
            cpu_param_dict[buffer] = buffer.data.cpu().pin_memory()

    return cpu_param_dict


def _gather_parameters_with_no_group_offloading_parent(
    module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[torch.nn.Parameter]:
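The `low_cpu_mem_usage` flag above trades transfer speed for host RAM: pinned (page-locked) copies allow truly asynchronous host-to-device transfers on a CUDA stream, while plain CPU copies keep the resident footprint lower but copy synchronously. A hedged sketch of how this machinery is typically driven from user code, assuming the public `apply_group_offloading` helper from `diffusers.hooks` (the exact signature is an assumption based on recent diffusers releases):

```python
import torch
from diffusers import LTXVideoTransformer3DModel
from diffusers.hooks import apply_group_offloading  # assumed public entry point

transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Keep weights on CPU and stream each group onto the GPU just in time;
# use_stream=True enables the pinned-memory + prefetching path shown above.
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```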
@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
        return hidden_states


class LTXVideoDownsampler3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: Union[int, Tuple[int, int, int]] = 1,
        is_causal: bool = True,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels

        out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])

        self.conv = LTXVideoCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            is_causal=is_causal,
            padding_mode=padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)

        residual = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        residual = residual.unflatten(1, (-1, self.group_size))
        residual = residual.mean(dim=2)

        hidden_states = self.conv(hidden_states)
        hidden_states = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        hidden_states = hidden_states + residual

        return hidden_states
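A shape walk-through may clarify the space-to-channel trick above; the concrete sizes are illustrative assumptions (in_channels=128, out_channels=256, stride=(2, 2, 2), so group_size = 128·8/256 = 4 and the conv emits 256/8 = 32 channels):

```python
# hidden_states: (B, 128, F, H, W); the leading-frame cat makes the frame count even.
# Residual path:
#   unflatten W, H, F by the strides    -> (B, 128, F/2, 2, H/2, 2, W/2, 2)
#   permute + flatten(1, 4)             -> (B, 128*8=1024, F/2, H/2, W/2)
#   unflatten(1, (-1, group_size)).mean -> (B, 1024/4=256, F/2, H/2, W/2)
# Conv path:
#   conv 128 -> 32 channels at stride 1 -> (B, 32, F, H, W)
#   same unflatten/permute/flatten      -> (B, 32*8=256, F/2, H/2, W/2)
# Sum of both paths: (B, 256, F/2, H/2, W/2); downsampling happens purely by
# pixel-unshuffle rather than by a strided convolution.
```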
class LTXVideoUpsampler3d(nn.Module):
    def __init__(
        self,
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
        is_causal: bool = True,
        residual: bool = False,
        upscale_factor: int = 1,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
            kernel_size=3,
            stride=1,
            is_causal=is_causal,
            padding_mode=padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module):
        return hidden_states

class LTXVideo095DownBlock3D(nn.Module):
    r"""
    Down block used in the LTXVideo model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer (including across the temporal dimension). If not used, the
            output dimension is the same as the input dimension.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
        downsample_type (`str`, defaults to `"conv"`):
            The type of downsampling layer to use: `"conv"`, `"spatial"`, `"temporal"`, or `"spatiotemporal"`.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        is_causal: bool = True,
        downsample_type: str = "conv",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        resnets = []
        for _ in range(num_layers):
            resnets.append(
                LTXVideoResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.downsamplers = None
        if spatio_temporal_scale:
            self.downsamplers = nn.ModuleList()

            if downsample_type == "conv":
                self.downsamplers.append(
                    LTXVideoCausalConv3d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=3,
                        stride=(2, 2, 2),
                        is_causal=is_causal,
                    )
                )
            elif downsample_type == "spatial":
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
                    )
                )
            elif downsample_type == "temporal":
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
                    )
                )
            elif downsample_type == "spatiotemporal":
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
                    )
                )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        r"""Forward method of the `LTXVideo095DownBlock3D` class."""

        for i, resnet in enumerate(self.resnets):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
            else:
                hidden_states = resnet(hidden_states, temb, generator)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states


# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXVideoMidBlock3d(nn.Module):
    r"""
@@ -593,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
        in_channels: int = 3,
        out_channels: int = 128,
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        down_block_types: Tuple[str, ...] = (
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
        ),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
@@ -617,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
        )

        # down blocks
        num_block_out_channels = len(block_out_channels)
        is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
        num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
        self.down_blocks = nn.ModuleList([])
        for i in range(num_block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
            if not is_ltx_095:
                output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
            else:
                output_channel = block_out_channels[i + 1]

            down_block = LTXVideoDownBlock3D(
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=layers_per_block[i],
                resnet_eps=resnet_norm_eps,
                spatio_temporal_scale=spatio_temporal_scaling[i],
                is_causal=is_causal,
            )
            if down_block_types[i] == "LTXVideoDownBlock3D":
                down_block = LTXVideoDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    num_layers=layers_per_block[i],
                    resnet_eps=resnet_norm_eps,
                    spatio_temporal_scale=spatio_temporal_scaling[i],
                    is_causal=is_causal,
                )
            elif down_block_types[i] == "LTXVideo095DownBlock3D":
                down_block = LTXVideo095DownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    num_layers=layers_per_block[i],
                    resnet_eps=resnet_norm_eps,
                    spatio_temporal_scale=spatio_temporal_scaling[i],
                    is_causal=is_causal,
                    downsample_type=downsample_type[i],
                )
            else:
                raise ValueError(f"Unknown down block type: {down_block_types[i]}")

            self.down_blocks.append(down_block)
@@ -794,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
        # timestep embedding
        self.time_embedder = None
        self.scale_shift_table = None
        self.timestep_scale_multiplier = None
        if timestep_conditioning:
            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
            self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)

@@ -803,6 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)

        if self.timestep_scale_multiplier is not None:
            temb = temb * self.timestep_scale_multiplier

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
@@ -891,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        out_channels: int = 3,
        latent_channels: int = 128,
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        down_block_types: Tuple[str, ...] = (
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
        ),
        decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
        upsample_residual: Tuple[bool, ...] = (False, False, False, False),
        upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
        timestep_conditioning: bool = False,
@@ -906,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        scaling_factor: float = 1.0,
        encoder_causal: bool = True,
        decoder_causal: bool = False,
        spatial_compression_ratio: int = None,
        temporal_compression_ratio: int = None,
    ) -> None:
        super().__init__()

@@ -913,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
            in_channels=in_channels,
            out_channels=latent_channels,
            block_out_channels=block_out_channels,
            down_block_types=down_block_types,
            spatio_temporal_scaling=spatio_temporal_scaling,
            layers_per_block=layers_per_block,
            downsample_type=downsample_type,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            resnet_norm_eps=resnet_norm_eps,
@@ -941,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        self.register_buffer("latents_mean", latents_mean, persistent=True)
        self.register_buffer("latents_std", latents_std, persistent=True)

        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
        self.spatial_compression_ratio = (
            patch_size * 2 ** sum(spatio_temporal_scaling)
            if spatial_compression_ratio is None
            else spatial_compression_ratio
        )
        self.temporal_compression_ratio = (
            patch_size_t * 2 ** sum(spatio_temporal_scaling)
            if temporal_compression_ratio is None
            else temporal_compression_ratio
        )

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
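A quick arithmetic check, using only the configs from this diff, shows why the explicit override is needed for 0.9.5:

```python
# 0.9.0/0.9.1: patch_size=4, spatio_temporal_scaling=(True, True, True, False)
#   derived spatial ratio = 4 * 2**3 = 32, which matches the actual VAE,
#   so no override is needed.
# 0.9.5: patch_size=4, spatio_temporal_scaling=(True, True, True, True)
#   derived spatial ratio = 4 * 2**4 = 64, but the 0.9.5 checkpoint actually
#   compresses by 32, hence spatial_compression_ratio=32 (and
#   temporal_compression_ratio=8) are passed explicitly in get_vae_config above.
```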
@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)
            input_tensor = self.conv_shortcut(input_tensor.contiguous())

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
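The `.contiguous()` call materializes a dense copy when `input_tensor` arrives as a non-contiguous view (for example after a permute or slice), which some backends and kernels require; a minimal illustration of the property being enforced (not code from this diff):

```python
import torch

x = torch.randn(2, 8, 4, 4).permute(0, 2, 3, 1)  # a view; memory is no longer dense
assert not x.is_contiguous()

y = x.contiguous()  # copies into a dense layout; a no-op if already contiguous
assert y.is_contiguous()
```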
@@ -14,7 +14,7 @@
# limitations under the License.

import math
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -113,20 +113,19 @@ class LTXVideoRotaryPosEmbed(nn.Module):
        self.patch_size_t = patch_size_t
        self.theta = theta

    def forward(
    def _prepare_video_coords(
        self,
        hidden_states: torch.Tensor,
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = hidden_states.size(0)

        rope_interpolation_scale: Tuple[torch.Tensor, float, float],
        device: torch.device,
    ) -> torch.Tensor:
        # Always compute rope in fp32
        grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
        grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
        grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
        grid_h = torch.arange(height, dtype=torch.float32, device=device)
        grid_w = torch.arange(width, dtype=torch.float32, device=device)
        grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
        grid = torch.stack(grid, dim=0)
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
@@ -138,6 +137,38 @@ class LTXVideoRotaryPosEmbed(nn.Module):

        grid = grid.flatten(2, 4).transpose(1, 2)

        return grid

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
        video_coords: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = hidden_states.size(0)

        if video_coords is None:
            grid = self._prepare_video_coords(
                batch_size,
                num_frames,
                height,
                width,
                rope_interpolation_scale=rope_interpolation_scale,
                device=hidden_states.device,
            )
        else:
            grid = torch.stack(
                [
                    video_coords[:, 0] / self.base_num_frames,
                    video_coords[:, 1] / self.base_height,
                    video_coords[:, 2] / self.base_width,
                ],
                dim=-1,
            )

        start = 1.0
        end = self.theta
        freqs = self.theta ** torch.linspace(
@@ -367,10 +398,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
        video_coords: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
@@ -389,7 +421,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
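The signature change makes the grid-based arguments optional because a caller can now supply explicit per-token coordinates instead; a hedged sketch of the two mutually exclusive call styles (variable names and the `video_coords` shape are assumptions inferred from the RoPE code above):

```python
# Style 1: regular latent grid; RoPE builds coordinates from the sizes.
out = transformer(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    timestep=timestep,
    encoder_attention_mask=prompt_attention_mask,
    num_frames=latent_num_frames,
    height=latent_height,
    width=latent_width,
    rope_interpolation_scale=rope_interpolation_scale,
)

# Style 2: explicit coordinates, e.g. for conditioning tokens at arbitrary frame
# indices; video_coords is (batch, 3, seq_len) holding (frame, y, x) per token,
# normalized inside the RoPE by base_num_frames / base_height / base_width.
out = transformer(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    timestep=timestep,
    encoder_attention_mask=prompt_attention_mask,
    video_coords=video_coords,
)
```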
@@ -264,7 +264,7 @@ else:
        ]
    )
    _import_structure["latte"] = ["LattePipeline"]
    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
    _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
    _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
    _import_structure["marigold"].extend(
@@ -618,7 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            LEditsPPPipelineStableDiffusion,
            LEditsPPPipelineStableDiffusionXL,
        )
        from .ltx import LTXImageToVideoPipeline, LTXPipeline
        from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
        from .lumina import LuminaPipeline, LuminaText2ImgPipeline
        from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
        from .marigold import (
@@ -63,6 +63,7 @@ EXAMPLE_DOC_STRING = """
        >>> from diffusers import FluxControlNetPipeline
        >>> from diffusers import FluxControlNetModel

        >>> base_model = "black-forest-labs/FLUX.1-dev"
        >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
        >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
        >>> pipe = FluxControlNetPipeline.from_pretrained(
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_ltx"] = ["LTXPipeline"]
    _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
    _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_ltx import LTXPipeline
        from .pipeline_ltx_condition import LTXConditionPipeline
        from .pipeline_ltx_image2video import LTXImageToVideoPipeline

else:
@@ -694,9 +694,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
        self._num_timesteps = len(timesteps)

        # 6. Prepare micro-conditions
        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
        rope_interpolation_scale = (
            1 / latent_frame_rate,
            self.vae_temporal_compression_ratio / frame_rate,
            self.vae_spatial_compression_ratio,
            self.vae_spatial_compression_ratio,
        )
1174  src/diffusers/pipelines/ltx/pipeline_ltx_condition.py  (new file)
File diff suppressed because it is too large
@@ -764,9 +764,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
        self._num_timesteps = len(timesteps)

        # 6. Prepare micro-conditions
        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
        rope_interpolation_scale = (
            1 / latent_frame_rate,
            self.vae_temporal_compression_ratio / frame_rate,
            self.vae_spatial_compression_ratio,
            self.vae_spatial_compression_ratio,
        )
@@ -108,31 +108,16 @@ def prompt_clean(text):
    return text


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor,
    latents_mean: torch.Tensor,
    latents_std: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample",
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
        encoder_output.latent_dist.logvar = torch.clamp(
            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
        )
        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
        encoder_output.latent_dist.logvar = torch.clamp(
            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
        )
        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return (encoder_output.latents - latents_mean) * latents_std
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
@@ -412,13 +397,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):

        if isinstance(generator, list):
            latent_condition = [
                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
                retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
            ]
            latent_condition = torch.cat(latent_condition)
        else:
            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
            latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

        latent_condition = (latent_condition - latents_mean) * latents_std

        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
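The refactor separates concerns: `retrieve_latents` now only extracts latents (deterministically via the distribution mode when `sample_mode="argmax"`), and the Wan-specific normalization is applied once at the call site. A minimal sketch of the resulting pattern (variable names follow the hunk above):

```python
# Deterministic latent extraction, then explicit normalization:
latent_condition = retrieve_latents(vae.encode(video_condition), sample_mode="argmax")
latent_condition = (latent_condition - latents_mean) * latents_std
```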
@@ -377,6 +377,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
@@ -397,6 +398,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
@@ -427,16 +430,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]
        if per_token_timesteps is not None:
            per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps

        prev_sample = sample + (sigma_next - sigma) * model_output
            sigmas = self.sigmas[:, None, None]
            lower_mask = sigmas < per_token_sigmas[None] - 1e-6
            lower_sigmas = lower_mask * sigmas
            lower_sigmas, _ = lower_sigmas.max(dim=0)
            dt = (per_token_sigmas - lower_sigmas)[..., None]
        else:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
            dt = sigma_next - sigma

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)
        prev_sample = sample + dt * model_output

        # upon completion increase step index by one
        self._step_index += 1
        if per_token_timesteps is None:
            # Cast sample back to model compatible dtype
            prev_sample = prev_sample.to(model_output.dtype)

        if not return_dict:
            return (prev_sample,)
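In the per-token branch, every token steps from its own sigma down to the largest grid sigma strictly below it, so tokens at different noise levels advance by different `dt`. A small numeric sketch of the masking logic (sigma values are illustrative, not from a real schedule):

```python
import torch

sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])    # scheduler grid, descending
per_token_sigmas = torch.tensor([[0.75, 0.40, 0.10]])  # one row per batch element

lower_mask = sigmas[:, None, None] < per_token_sigmas[None] - 1e-6
lower_sigmas, _ = (lower_mask * sigmas[:, None, None]).max(dim=0)

print(lower_sigmas)                     # tensor([[0.5000, 0.2500, 0.0000]])
print(per_token_sigmas - lower_sigmas)  # per-token dt: tensor([[0.2500, 0.1500, 0.1000]])
```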
@@ -1217,6 +1217,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class LTXConditionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class LTXImageToVideoPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
284  tests/pipelines/ltx/test_ltx_condition.py  (new file)
@@ -0,0 +1,284 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKLLTXVideo,
    FlowMatchEulerDiscreteScheduler,
    LTXConditionPipeline,
    LTXVideoTransformer3DModel,
)
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = LTXConditionPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = LTXVideoTransformer3DModel(
            in_channels=8,
            out_channels=8,
            patch_size=1,
            patch_size_t=1,
            num_attention_heads=4,
            attention_head_dim=8,
            cross_attention_dim=32,
            num_layers=1,
            caption_channels=32,
        )

        torch.manual_seed(0)
        vae = AutoencoderKLLTXVideo(
            in_channels=3,
            out_channels=3,
            latent_channels=8,
            block_out_channels=(8, 8, 8, 8),
            decoder_block_out_channels=(8, 8, 8, 8),
            layers_per_block=(1, 1, 1, 1, 1),
            decoder_layers_per_block=(1, 1, 1, 1, 1),
            spatio_temporal_scaling=(True, True, False, False),
            decoder_spatio_temporal_scaling=(True, True, False, False),
            decoder_inject_noise=(False, False, False, False, False),
            upsample_residual=(False, False, False, False),
            upsample_factor=(1, 1, 1, 1),
            timestep_conditioning=False,
            patch_size=1,
            patch_size_t=1,
            encoder_causal=True,
            decoder_causal=False,
        )
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0, use_conditions=False):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
        if use_conditions:
            conditions = LTXVideoCondition(
                image=image,
            )
        else:
            conditions = None

        inputs = {
            "conditions": conditions,
            "image": None if use_conditions else image,
            "prompt": "dance monkey",
            "negative_prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "height": 32,
            "width": 32,
            # 8 * k + 1 is the recommendation
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs2 = self.get_dummy_inputs(device, use_conditions=True)
        video = pipe(**inputs).frames
        generated_video = video[0]
        video2 = pipe(**inputs2).frames
        generated_video2 = video2[0]

        self.assertEqual(generated_video.shape, (9, 3, 32, 32))

        max_diff = np.abs(generated_video - generated_video2).max()
        self.assertLessEqual(max_diff, 1e-3)

    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )