Merge branch 'ltx-2-transformer' of github.com:huggingface/diffusers into ltx-2-transformer
22  .github/workflows/codeql.yml  (vendored, new file)
@@ -0,0 +1,22 @@
---
name: CodeQL Security Analysis For Github Actions

on:
  push:
    branches: ["main"]
  workflow_dispatch:
  # pull_request:

jobs:
  codeql:
    name: CodeQL Analysis
    uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
    permissions:
      security-events: write
      packages: read
      actions: read
      contents: read
    with:
      languages: '["actions","python"]'
      queries: 'security-extended,security-and-quality'
      runner: 'ubuntu-latest' #optional if need custom runner
@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.

- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.

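As a quick, non-authoritative illustration of these tips (the checkpoint id below is a placeholder and the exact keyword arguments should be verified against the installed diffusers version), a text-to-video call might look like:

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# Placeholder checkpoint id; substitute the LTX-Video variant you actually use.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A drone shot of a rocky coastline at sunset",
    num_inference_steps=40,
    guidance_scale=5.0,            # use 1.0 for guidance-distilled variants
    decode_timestep=0.05,          # timestep-aware VAE variants (0.9.1 and above)
    image_cond_noise_scale=0.025,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```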
@@ -329,7 +329,7 @@ export_to_video(video, "output.mp4", fps=24)

<details>
<summary>Show example code</summary>


```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
@@ -474,6 +474,12 @@ export_to_video(video, "output.mp4", fps=24)

</details>
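To make the two-stage 0.9.7 flow from the tips above concrete, here is a condensed sketch (not taken from the example in this diff; checkpoint ids, resolutions, and keyword arguments are assumptions to check against the installed version). In practice the upscaled latents are usually refined with a short second denoising pass before decoding, as the full example code in these docs shows:

```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.utils import export_to_video

# Placeholder checkpoint ids for the base model and the spatial latent upsampler.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
    "Lightricks/ltxv-spatial-upscaler-0.9.7", vae=pipe.vae, torch_dtype=torch.bfloat16
).to("cuda")

prompt = "A drone shot of a rocky coastline at sunset"

# 1. Quickly generate a low-resolution video, keeping the result in latent space.
latents = pipe(prompt=prompt, width=512, height=320, num_frames=97, output_type="latent").frames

# 2. Spatially upscale the latents, then decode them with the shared VAE.
video = pipe_upsample(latents=latents, output_type="pil").frames[0]
export_to_video(video, "output.mp4", fps=24)
```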

## LTXI2VLongMultiPromptPipeline

[[autodoc]] LTXI2VLongMultiPromptPipeline
  - all
  - __call__

## LTXPipeline

[[autodoc]] LTXPipeline

@@ -356,6 +356,7 @@ else:
            "KDPM2AncestralDiscreteScheduler",
            "KDPM2DiscreteScheduler",
            "LCMScheduler",
            "LTXEulerAncestralRFScheduler",
            "PNDMScheduler",
            "RePaintScheduler",
            "SASolverScheduler",
@@ -543,6 +544,7 @@ else:
            "LTX2ImageToVideoPipeline",
            "LTX2Pipeline",
            "LTXConditionPipeline",
            "LTXI2VLongMultiPromptPipeline",
            "LTXImageToVideoPipeline",
            "LTXLatentUpsamplePipeline",
            "LTXPipeline",
@@ -1096,6 +1098,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        KDPM2AncestralDiscreteScheduler,
        KDPM2DiscreteScheduler,
        LCMScheduler,
        LTXEulerAncestralRFScheduler,
        PNDMScheduler,
        RePaintScheduler,
        SASolverScheduler,
@@ -1262,6 +1265,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        LTX2ImageToVideoPipeline,
        LTX2Pipeline,
        LTXConditionPipeline,
        LTXI2VLongMultiPromptPipeline,
        LTXImageToVideoPipeline,
        LTXLatentUpsamplePipeline,
        LTXPipeline,

@@ -1420,6 +1420,7 @@ def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
@@ -1427,6 +1428,9 @@ def _flash_attention(
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    lse = None
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 2.")

    if _parallel_config is None:
        out = flash_attn_func(
            q=query,
@@ -1469,6 +1473,7 @@ def _flash_attention_hub(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
@@ -1476,6 +1481,9 @@ def _flash_attention_hub(
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    lse = None
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 2.")

    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
    out = func(
        q=query,
@@ -1612,11 +1620,15 @@ def _flash_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 3.")

    out, lse = _wrapped_flash_attn_3(
        q=query,
        k=key,
@@ -1636,6 +1648,7 @@ def _flash_attention_3_hub(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
@@ -1646,6 +1659,8 @@ def _flash_attention_3_hub(
) -> torch.Tensor:
    if _parallel_config:
        raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 3.")

    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
    out = func(
@@ -1785,12 +1800,16 @@ def _aiter_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for aiter attention")

    if not return_lse and torch.is_grad_enabled():
        # aiter requires return_lse=True by assertion when gradients are enabled.
        out, lse, *_ = aiter_flash_attn_func(
@@ -2028,6 +2047,7 @@ def _native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
@@ -2035,6 +2055,9 @@ def _native_flash_attention(
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for aiter attention")

    lse = None
    if _parallel_config is None and not return_lse:
        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2113,11 +2136,14 @@ def _native_npu_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for NPU attention")
    if return_lse:
        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
@@ -2148,10 +2174,13 @@ def _native_xla_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for XLA attention")
    if return_lse:
        raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2175,11 +2204,14 @@ def _sage_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    lse = None
    if _parallel_config is None:
        out = sageattn(
@@ -2223,11 +2255,14 @@ def _sage_attention_hub(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    lse = None
    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
    if _parallel_config is None:
@@ -2309,11 +2344,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    return sageattn_qk_int8_pv_fp8_cuda(
        q=query,
        k=key,
@@ -2333,11 +2371,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    return sageattn_qk_int8_pv_fp8_cuda_sm90(
        q=query,
        k=key,
@@ -2357,11 +2398,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    return sageattn_qk_int8_pv_fp16_cuda(
        q=query,
        k=key,
@@ -2381,11 +2425,14 @@ def _sage_qk_int8_pv_fp16_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    return sageattn_qk_int8_pv_fp16_triton(
        q=query,
        k=key,

@@ -22,7 +22,7 @@ import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,11 +717,7 @@ class FluxTransformer2DModel(
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        if is_torch_npu_available():
            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
        else:
            image_rotary_emb = self.pos_embed(ids)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")

@@ -21,7 +21,7 @@ import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
@@ -835,14 +835,8 @@ class Flux2Transformer2DModel(
        if txt_ids.ndim == 3:
            txt_ids = txt_ids[0]

        if is_torch_npu_available():
            freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
            image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
            freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
            text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
        else:
            image_rotary_emb = self.pos_embed(img_ids)
            text_rotary_emb = self.pos_embed(txt_ids)
        image_rotary_emb = self.pos_embed(img_ids)
        text_rotary_emb = self.pos_embed(txt_ids)
        concat_rotary_emb = (
            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),

@@ -21,7 +21,7 @@ import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import is_torch_npu_available, logging
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -499,11 +499,7 @@ class LongCatImageTransformer2DModel(
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        ids = torch.cat((txt_ids, img_ids), dim=0)
        if is_torch_npu_available():
            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
        else:
            image_rotary_emb = self.pos_embed(ids)
        image_rotary_emb = self.pos_embed(ids)

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:

@@ -21,7 +21,7 @@ import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import is_torch_npu_available, logging
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -530,11 +530,7 @@ class OvisImageTransformer2DModel(
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        if is_torch_npu_available():
            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
        else:
            image_rotary_emb = self.pos_embed(ids)
        image_rotary_emb = self.pos_embed(ids)

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

@@ -288,6 +288,7 @@ else:
        "LTXImageToVideoPipeline",
        "LTXConditionPipeline",
        "LTXLatentUpsamplePipeline",
        "LTXI2VLongMultiPromptPipeline",
    ]
    _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline"]
    _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
@@ -730,7 +731,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            LEditsPPPipelineStableDiffusionXL,
        )
        from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
        from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
        from .ltx import (
            LTXConditionPipeline,
            LTXI2VLongMultiPromptPipeline,
            LTXImageToVideoPipeline,
            LTXLatentUpsamplePipeline,
            LTXPipeline,
        )
        from .ltx2 import LTX2ImageToVideoPipeline, LTX2Pipeline
        from .lucy import LucyEditPipeline
        from .lumina import LuminaPipeline, LuminaText2ImgPipeline

@@ -25,6 +25,7 @@ else:
    _import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
    _import_structure["pipeline_ltx"] = ["LTXPipeline"]
    _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
    _import_structure["pipeline_ltx_i2v_long_multi_prompt"] = ["LTXI2VLongMultiPromptPipeline"]
    _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
    _import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]

@@ -39,6 +40,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .modeling_latent_upsampler import LTXLatentUpsamplerModel
    from .pipeline_ltx import LTXPipeline
    from .pipeline_ltx_condition import LTXConditionPipeline
    from .pipeline_ltx_i2v_long_multi_prompt import LTXI2VLongMultiPromptPipeline
    from .pipeline_ltx_image2video import LTXImageToVideoPipeline
    from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline


1408  src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py  (new file; diff suppressed because it is too large)
@@ -66,6 +66,7 @@ else:
    _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
    _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
    _import_structure["scheduling_lcm"] = ["LCMScheduler"]
    _import_structure["scheduling_ltx_euler_ancestral_rf"] = ["LTXEulerAncestralRFScheduler"]
    _import_structure["scheduling_pndm"] = ["PNDMScheduler"]
    _import_structure["scheduling_repaint"] = ["RePaintScheduler"]
    _import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
@@ -168,6 +169,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
    from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
    from .scheduling_lcm import LCMScheduler
    from .scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFScheduler
    from .scheduling_pndm import PNDMScheduler
    from .scheduling_repaint import RePaintScheduler
    from .scheduling_sasolver import SASolverScheduler

386  src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py  (new file)
@@ -0,0 +1,386 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
LTXEulerAncestralRFScheduler

This scheduler implements a K-diffusion style Euler-Ancestral sampler specialized for flow / CONST parameterization,
closely mirroring ComfyUI's `sample_euler_ancestral_RF` implementation used for LTX-Video.

Reference implementation (ComfyUI):
    comfy.k_diffusion.sampling.sample_euler_ancestral_RF
"""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class LTXEulerAncestralRFSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor`):
            Updated sample for the next step in the denoising process.
    """

    prev_sample: torch.FloatTensor


class LTXEulerAncestralRFScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler-Ancestral scheduler for LTX-Video (RF / CONST parametrization).

    This scheduler is intended for models where the network is trained with a CONST-like parameterization (as in LTXV /
    FLUX). It approximates ComfyUI's `sample_euler_ancestral_RF` sampler and is useful when reproducing ComfyUI
    workflows inside diffusers.

    The scheduler can either:
    - reuse the [`FlowMatchEulerDiscreteScheduler`] sigma / timestep logic when only `num_inference_steps` is provided
      (default diffusers-style usage), or
    - follow an explicit ComfyUI-style sigma schedule when `sigmas` (or `timesteps`) are passed to [`set_timesteps`].

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            Included for config compatibility; not used to build the schedule.
        eta (`float`, defaults to 1.0):
            Stochasticity parameter. `eta=0.0` yields deterministic DDIM-like sampling; `eta=1.0` matches ComfyUI's
            default RF behavior.
        s_noise (`float`, defaults to 1.0):
            Global scaling factor for the stochastic noise term.
    """

    # Allow config migration from the flow-match scheduler and back.
    _compatibles = ["FlowMatchEulerDiscreteScheduler"]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        eta: float = 1.0,
        s_noise: float = 1.0,
    ):
        # Note: num_train_timesteps is kept only for config compatibility.
        self.num_inference_steps: Optional[int] = None
        self.sigmas: Optional[torch.Tensor] = None
        self.timesteps: Optional[torch.Tensor] = None
        self._step_index: Optional[int] = None
        self._begin_index: Optional[int] = None

    @property
    def step_index(self) -> Optional[int]:
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """
        The index for the first timestep. It can be set from a pipeline with `set_begin_index` to support
        image-to-image like workflows that start denoising part-way through the schedule.
        """
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """
        Included for API compatibility; not strictly needed here but kept to allow pipelines that call
        `set_begin_index`.
        """
        self._begin_index = begin_index

    def index_for_timestep(
        self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Map a (continuous) `timestep` value to an index into `self.timesteps`.

        This follows the convention used in other discrete schedulers: if the same timestep value appears multiple
        times in the schedule (which can happen when starting in the middle of the schedule), the *second* occurrence
        is used for the first `step` call so that no sigma is accidentally skipped.
        """
        if schedule_timesteps is None:
            if self.timesteps is None:
                raise ValueError("Timesteps have not been set. Call `set_timesteps` first.")
            schedule_timesteps = self.timesteps

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(schedule_timesteps.device)

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        if len(indices) == 0:
            raise ValueError(
                "Passed `timestep` is not in `self.timesteps`. Make sure to use values from `scheduler.timesteps`."
            )

        return indices[pos].item()

    def _init_step_index(self, timestep: Union[float, torch.Tensor]):
        """
        Initialize the internal step index based on a given timestep.
        """
        if self.timesteps is None:
            raise ValueError("Timesteps have not been set. Call `set_timesteps` first.")

        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device, None] = None,
        sigmas: Optional[Union[List[float], torch.Tensor]] = None,
        timesteps: Optional[Union[List[float], torch.Tensor]] = None,
        mu: Optional[float] = None,
        **kwargs,
    ):
        """
        Set the sigma / timestep schedule for sampling.

        When `sigmas` or `timesteps` are provided explicitly, they are used as the RF sigma schedule (ComfyUI-style)
        and are expected to include the terminal 0.0. When both are `None`, the scheduler reuses the
        [`FlowMatchEulerDiscreteScheduler`] logic to generate sigmas from `num_inference_steps` and the stored config
        (including any resolution-dependent shifting, Karras/beta schedules, etc.).

        Args:
            num_inference_steps (`int`, *optional*):
                Number of denoising steps. If provided together with explicit `sigmas`/`timesteps`, they are expected
                to be consistent and are otherwise ignored with a warning.
            device (`str` or `torch.device`, *optional*):
                Device to move the internal tensors to.
            sigmas (`List[float]` or `torch.Tensor`, *optional*):
                Explicit sigma schedule, e.g. `[1.0, 0.99, ..., 0.0]`.
            timesteps (`List[float]` or `torch.Tensor`, *optional*):
                Optional alias for `sigmas`. If `sigmas` is None and `timesteps` is provided, timesteps are treated as
                sigmas.
            mu (`float`, *optional*):
                Optional shift parameter used when delegating to [`FlowMatchEulerDiscreteScheduler.set_timesteps`] and
                `config.use_dynamic_shifting` is `True`.
        """
        # 1. Auto-generate schedule (FlowMatch-style) when no explicit sigmas/timesteps are given
        if sigmas is None and timesteps is None:
            if num_inference_steps is None:
                raise ValueError(
                    "LTXEulerAncestralRFScheduler.set_timesteps requires either explicit `sigmas`/`timesteps` "
                    "or a `num_inference_steps` value."
                )

            # We reuse FlowMatchEulerDiscreteScheduler to construct a sigma schedule that is
            # consistent with the original LTX training setup (including optional time shifting,
            # Karras / exponential / beta schedules, etc.).
            from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

            base_scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.config)
            base_scheduler.set_timesteps(
                num_inference_steps=num_inference_steps,
                device=device,
                sigmas=None,
                mu=mu,
                timesteps=None,
            )

            self.num_inference_steps = base_scheduler.num_inference_steps
            # Keep sigmas / timesteps on the requested device so step() can operate on-device without
            # extra transfers.
            self.sigmas = base_scheduler.sigmas.to(device=device)
            self.timesteps = base_scheduler.timesteps.to(device=device)
            self._step_index = None
            self._begin_index = None
            return

        # 2. Explicit sigma schedule (ComfyUI-style path)
        if sigmas is None:
            # `timesteps` is treated as sigmas in RF / flow-matching setups.
            sigmas = timesteps

        if isinstance(sigmas, list):
            sigmas_tensor = torch.tensor(sigmas, dtype=torch.float32)
        elif isinstance(sigmas, torch.Tensor):
            sigmas_tensor = sigmas.to(dtype=torch.float32)
        else:
            raise TypeError(f"`sigmas` must be a list or torch.Tensor, got {type(sigmas)}.")

        if sigmas_tensor.ndim != 1:
            raise ValueError(f"`sigmas` must be a 1D tensor, got shape {tuple(sigmas_tensor.shape)}.")

        if sigmas_tensor[-1].abs().item() > 1e-6:
            logger.warning(
                "The last sigma in the schedule is not zero (%.6f). "
                "For best compatibility with ComfyUI's RF sampler, the terminal sigma "
                "should be 0.0.",
                sigmas_tensor[-1].item(),
            )

        # Move to device once, then derive timesteps.
        if device is not None:
            sigmas_tensor = sigmas_tensor.to(device)

        # Internal sigma schedule stays in [0, 1] (as provided).
        self.sigmas = sigmas_tensor
        # Timesteps are scaled to match the training setup of LTX (FlowMatch-style),
        # where the network expects timesteps on [0, num_train_timesteps].
        # This keeps the transformer conditioning in the expected range while the RF
        # scheduler still operates on the raw sigma values.
        num_train = float(getattr(self.config, "num_train_timesteps", 1000))
        self.timesteps = sigmas_tensor * num_train

        if num_inference_steps is not None and num_inference_steps != len(sigmas) - 1:
            logger.warning(
                "Provided `num_inference_steps=%d` does not match `len(sigmas)-1=%d`. "
                "Overriding `num_inference_steps` with `len(sigmas)-1`.",
                num_inference_steps,
                len(sigmas) - 1,
            )

        self.num_inference_steps = len(sigmas) - 1
        self._step_index = None
        self._begin_index = None

    def _sigma_broadcast(self, sigma: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
        """
        Helper to broadcast a scalar sigma to the shape of `sample`.
        """
        while sigma.ndim < sample.ndim:
            sigma = sigma.view(*sigma.shape, 1)
        return sigma

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[LTXEulerAncestralRFSchedulerOutput, Tuple[torch.FloatTensor]]:
        """
        Perform a single Euler-Ancestral RF update step.

        Args:
            model_output (`torch.FloatTensor`):
                Raw model output at the current step. Interpreted under the CONST parametrization as `v_t`, with
                denoised state reconstructed as `x0 = x_t - sigma_t * v_t`.
            timestep (`float` or `torch.Tensor`):
                The current sigma value (must match one entry in `self.timesteps`).
            sample (`torch.FloatTensor`):
                Current latent sample `x_t`.
            generator (`torch.Generator`, *optional*):
                Optional generator for reproducible noise.
            return_dict (`bool`):
                If `True`, return a `LTXEulerAncestralRFSchedulerOutput`; otherwise return a tuple where the first
                element is the updated sample.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `LTXEulerAncestralRFScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` values as `timestep`."
                ),
            )

        if self.sigmas is None or self.timesteps is None:
            raise ValueError("Scheduler has not been initialized. Call `set_timesteps` before `step`.")

        if self._step_index is None:
            self._init_step_index(timestep)

        i = self._step_index
        if i >= len(self.sigmas) - 1:
            # Already at the end; simply return the current sample.
            prev_sample = sample
        else:
            # Work in float32 for numerical stability
            sample_f = sample.to(torch.float32)
            model_output_f = model_output.to(torch.float32)

            sigma = self.sigmas[i]
            sigma_next = self.sigmas[i + 1]

            sigma_b = self._sigma_broadcast(sigma.view(1), sample_f)
            sigma_next_b = self._sigma_broadcast(sigma_next.view(1), sample_f)

            # Approximate denoised x0 under CONST parametrization:
            # x0 = x_t - sigma_t * v_t
            denoised = sample_f - sigma_b * model_output_f

            if sigma_next.abs().item() < 1e-8:
                # Final denoising step
                x = denoised
            else:
                eta = float(self.config.eta)
                s_noise = float(self.config.s_noise)

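                # Summary of the update performed below, with alpha = 1 - sigma and eps ~ N(0, I):
                #   sigma_down = sigma_{i+1} * (1 + eta * (sigma_{i+1} / sigma_i - 1))
                #   x_det      = (sigma_down / sigma_i) * x_i + (1 - sigma_down / sigma_i) * x0
                # and, when eta > 0 and s_noise > 0,
                #   x_{i+1}    = (alpha_{i+1} / alpha_down) * x_det
                #                + s_noise * sqrt(sigma_{i+1}^2 - sigma_down^2 * alpha_{i+1}^2 / alpha_down^2) * eps
                # otherwise x_{i+1} = x_det.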
                # Downstep computation (ComfyUI RF variant)
                downstep_ratio = 1.0 + (sigma_next / sigma - 1.0) * eta
                sigma_down = sigma_next * downstep_ratio

                alpha_ip1 = 1.0 - sigma_next
                alpha_down = 1.0 - sigma_down

                # Deterministic part (Euler step in (x, x0)-space)
                sigma_down_b = self._sigma_broadcast(sigma_down.view(1), sample_f)
                alpha_ip1_b = self._sigma_broadcast(alpha_ip1.view(1), sample_f)
                alpha_down_b = self._sigma_broadcast(alpha_down.view(1), sample_f)

                sigma_ratio = sigma_down_b / sigma_b
                x = sigma_ratio * sample_f + (1.0 - sigma_ratio) * denoised

                # Stochastic ancestral noise
                if eta > 0.0 and s_noise > 0.0:
                    renoise_coeff = (
                        (sigma_next_b**2 - sigma_down_b**2 * alpha_ip1_b**2 / (alpha_down_b**2 + 1e-12))
                        .clamp(min=0.0)
                        .sqrt()
                    )

                    noise = randn_tensor(
                        sample_f.shape, generator=generator, device=sample_f.device, dtype=sample_f.dtype
                    )
                    x = (alpha_ip1_b / (alpha_down_b + 1e-12)) * x + noise * renoise_coeff * s_noise

            prev_sample = x.to(sample.dtype)

        # Advance internal step index
        self._step_index = min(self._step_index + 1, len(self.sigmas) - 1)

        if not return_dict:
            return (prev_sample,)

        return LTXEulerAncestralRFSchedulerOutput(prev_sample=prev_sample)

    def __len__(self) -> int:
        # For compatibility with other schedulers; used e.g. in some training
        # utilities to infer the maximum number of training timesteps.
        return int(getattr(self.config, "num_train_timesteps", 1000))
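Since the scheduler is new in this PR, a short usage sketch may be helpful; it assumes a diffusers build that contains this file and uses a zero "velocity" stand-in instead of a real LTX transformer:

```python
import torch
from diffusers import LTXEulerAncestralRFScheduler

scheduler = LTXEulerAncestralRFScheduler(eta=1.0, s_noise=1.0)

# Either derive the schedule from a step count (FlowMatch-style)...
# scheduler.set_timesteps(num_inference_steps=30, device="cpu")
# ...or pass an explicit ComfyUI-style sigma schedule that ends in 0.0.
scheduler.set_timesteps(sigmas=torch.linspace(1.0, 0.0, 31).tolist(), device="cpu")

sample = torch.randn(1, 128, 8, 32, 32)  # placeholder latent shape
generator = torch.Generator().manual_seed(0)

# The last timestep corresponds to the terminal sigma of 0.0, so iterate up to it.
for t in scheduler.timesteps[:-1]:
    velocity = torch.zeros_like(sample)  # stand-in for transformer(sample, t, ...)
    sample = scheduler.step(velocity, t, sample, generator=generator).prev_sample
```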
@@ -2679,6 +2679,21 @@ class LCMScheduler(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class LTXEulerAncestralRFScheduler(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class PNDMScheduler(metaclass=DummyObject):
    _backends = ["torch"]

@@ -1922,6 +1922,21 @@ class LTXConditionPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class LTXI2VLongMultiPromptPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class LTXImageToVideoPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
