diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index f74f608457..c3bb7a00a4 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -246,7 +246,6 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin): return hidden_states -@maybe_allow_in_graph class LTX2VideoTransformerBlock(nn.Module): r""" Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video). @@ -802,7 +801,6 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): return cos_freqs, sin_freqs -@maybe_allow_in_graph class LTX2VideoTransformer3DModel( ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin ): @@ -834,7 +832,7 @@ class LTX2VideoTransformer3DModel( _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] - _repeated_blocks = ["LTXVideoTransformerBlock"] + _repeated_blocks = ["LTX2VideoTransformerBlock"] _cp_plan = { "": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),