
[LoRA] add LoRA support to LTX-2 (#12933)

* up

* fixes

* tests

* docs.

* fix

* change loading info.

* up

* up
This commit is contained in:
Sayak Paul
2026-01-10 11:27:22 +05:30
committed by GitHub
parent d44b5f86e6
commit ed6e5ecf67
9 changed files with 587 additions and 3 deletions


@@ -33,6 +33,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LTX2LoraLoaderMixin`] provides similar functions for [LTX-2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more.
> [!TIP]
@@ -62,6 +63,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
## LTX2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin
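A minimal usage sketch of the mixin through [`LTX2Pipeline`]; the base checkpoint and LoRA repository ids below are placeholders, not verified model ids:

```python
import torch
from diffusers import LTX2Pipeline

# Placeholder ids; substitute a real LTX-2 base checkpoint and a compatible LoRA.
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("your-org/your-ltx2-lora", adapter_name="style")

# Optionally merge the LoRA into the base weights for slightly faster inference.
pipe.fuse_lora(lora_scale=1.0)
```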
## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin


@@ -14,6 +14,10 @@
# LTX-2
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
You can find all the original LTX-2 checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
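A short, hedged text-to-video sketch; the argument names mirror those exercised by the pipeline's tests, while the checkpoint id, resolution/frame choices, and the output attribute are illustrative assumptions:

```python
import torch
from diffusers import LTX2Pipeline

# Placeholder checkpoint id; point this at an actual LTX-2 repository.
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16).to("cuda")

output = pipe(
    prompt="a robot dancing in the rain",
    height=512,
    width=768,
    num_frames=121,
    frame_rate=25.0,
    num_inference_steps=40,
    guidance_scale=4.0,
    output_type="np",
)
# LTX-style pipelines typically expose the decoded video as `output.frames[0]`;
# LTX-2 additionally generates synchronized audio (attribute name not shown here).
video = output.frames[0]
```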


@@ -67,6 +67,7 @@ if is_torch_available():
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTX2LoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
@@ -121,6 +122,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        HunyuanVideoLoraLoaderMixin,
        KandinskyLoraLoaderMixin,
        LoraLoaderMixin,
        LTX2LoraLoaderMixin,
        LTXVideoLoraLoaderMixin,
        Lumina2LoraLoaderMixin,
        Mochi1LoraLoaderMixin,


@@ -2140,6 +2140,54 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
    return converted_state_dict


def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
    # Remove the prefix
    state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")}
    converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}

    if non_diffusers_prefix == "diffusion_model":
        rename_dict = {
            "patchify_proj": "proj_in",
            "audio_patchify_proj": "audio_proj_in",
            "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
            "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
            "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
            "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
            "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
            "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
            "q_norm": "norm_q",
            "k_norm": "norm_k",
        }
    else:
        rename_dict = {"aggregate_embed": "text_proj_in"}

    # Apply renaming
    renamed_state_dict = {}
    for key, value in converted_state_dict.items():
        new_key = key[:]
        for old_pattern, new_pattern in rename_dict.items():
            new_key = new_key.replace(old_pattern, new_pattern)
        renamed_state_dict[new_key] = value

    # Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed
    final_state_dict = {}
    for key, value in renamed_state_dict.items():
        if key.startswith("adaln_single."):
            new_key = key.replace("adaln_single.", "time_embed.")
            final_state_dict[new_key] = value
        elif key.startswith("audio_adaln_single."):
            new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
            final_state_dict[new_key] = value
        else:
            final_state_dict[key] = value

    # Add transformer prefix
    prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors"
    final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()}

    return final_state_dict
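For illustration, a toy call to the converter above with made-up LoRA tensors; the keys are chosen only to exercise the rename and prefix rules:

```python
import torch

raw = {
    "diffusion_model.patchify_proj.lora_A.weight": torch.zeros(4, 8),
    "diffusion_model.adaln_single.linear.lora_B.weight": torch.zeros(8, 4),
}
converted = _convert_non_diffusers_ltx2_lora_to_diffusers(raw)
print(sorted(converted))
# ['transformer.proj_in.lora_A.weight', 'transformer.time_embed.linear.lora_B.weight']
```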
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
    has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
    if has_diffusion_model:


@@ -48,6 +48,7 @@ from .lora_conversion_utils import (
    _convert_non_diffusers_flux2_lora_to_diffusers,
    _convert_non_diffusers_hidream_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
    _convert_non_diffusers_ltx2_lora_to_diffusers,
    _convert_non_diffusers_ltxv_lora_to_diffusers,
    _convert_non_diffusers_lumina2_lora_to_diffusers,
    _convert_non_diffusers_qwen_lora_to_diffusers,
@@ -74,6 +75,7 @@ logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
LTX2_CONNECTOR_NAME = "connectors"
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
@@ -3011,6 +3013,233 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
        super().unfuse_lora(components=components, **kwargs)


class LTX2LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`].
    """

    _lora_loadable_modules = ["transformer", "connectors"]
    transformer_name = TRANSFORMER_NAME
    connectors_name = LTX2_CONNECTOR_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
        """
        # Load the main state dict first, which has the LoRA layers for either the
        # transformer, the text encoder, or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale' from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        final_state_dict = state_dict
        is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
        has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict)
        if is_non_diffusers_format:
            final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict)
        if has_connector:
            connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(
                state_dict, "text_embedding_projection"
            )
            final_state_dict.update(connectors_state_dict)

        out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict
        return out
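For reference, `lora_state_dict` can be called on its own to inspect how a checkpoint's keys come out after conversion; the repository and file names below are placeholders:

```python
from diffusers import LTX2Pipeline

state_dict = LTX2Pipeline.lora_state_dict("your-org/your-ltx2-lora", weight_name="lora.safetensors")

# Transformer LoRA keys carry the "transformer." prefix; any text-connector LoRA keys
# (converted from "text_embedding_projection.*") carry the "connectors." prefix.
print(sorted({k.split(".")[0] for k in state_dict}))
```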
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        transformer_peft_state_dict = {
            k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
        }
        connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")}

        self.load_lora_into_transformer(
            transformer_peft_state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )
        if connectors_peft_state_dict:
            self.load_lora_into_transformer(
                connectors_peft_state_dict,
                transformer=getattr(self, self.connectors_name)
                if not hasattr(self, "connectors")
                else self.connectors,
                adapter_name=adapter_name,
                metadata=metadata,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
                hotswap=hotswap,
                prefix=self.connectors_name,
            )
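Since the method also accepts an in-memory state dict, the fetch and load steps can be chained explicitly; a hedged sketch assuming `pipe` is an instantiated `LTX2Pipeline` and the LoRA id is a placeholder:

```python
state_dict = pipe.lora_state_dict("your-org/your-ltx2-lora")
pipe.load_lora_weights(state_dict, adapter_name="my_lora")

# Transformer LoRA layers end up on `pipe.transformer`; connector LoRA layers
# (when present in the checkpoint) end up on `pipe.connectors`.
```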
    @classmethod
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
        prefix: str = "transformer",
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {prefix}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
            prefix=prefix,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: Optional[dict] = None,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
        """
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )
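A hedged sketch of the corresponding save path, assuming `peft` is installed and `pipe.transformer` already carries trained LoRA layers; `get_peft_model_state_dict` comes from `peft`, and the output directory and file name are arbitrary:

```python
from peft import get_peft_model_state_dict

from diffusers import LTX2Pipeline

transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
LTX2Pipeline.save_lora_weights(
    save_directory="./ltx2-lora",
    transformer_lora_layers=transformer_lora_layers,
    weight_name="pytorch_lora_weights.safetensors",
)
```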
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
        """
        super().unfuse_lora(components=components, **kwargs)
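Fusing folds the loaded LoRA into the transformer weights and can be reverted later; a brief sketch assuming a LoRA is already loaded on `pipe`:

```python
pipe.fuse_lora(components=["transformer"], lora_scale=0.8)
# ... run inference with the fused weights ...
pipe.unfuse_lora(components=["transformer"])
```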
class SanaLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].


@@ -67,6 +67,8 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
"LTX2TextConnectors": lambda model_cls, weights: weights,
}
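Both new entries map adapter weights through the identity function, so per-adapter scales are passed as plain floats (or nested dicts) via `set_adapters`; a sketch assuming two LoRAs were loaded on `pipe` under these adapter names:

```python
pipe.set_adapters(["style", "motion"], adapter_weights=[1.0, 0.6])
```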


@@ -5,6 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
@@ -252,7 +253,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
        return hidden_states, attention_mask


-class LTX2TextConnectors(ModelMixin, ConfigMixin):
+class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
    """
    Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
    streams.
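With `PeftAdapterMixin` in the bases, LoRA layers can now be attached to the connectors directly; a minimal hedged sketch where the target module names are assumptions based on the attention naming used elsewhere in this PR:

```python
from peft import LoraConfig

# `connectors` is an instantiated LTX2TextConnectors; target modules are assumed, not verified.
connectors.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))
```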


@@ -21,7 +21,7 @@ import torch
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
from ...models.transformers import LTX2VideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -184,7 +184,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    return noise_cfg


-class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation.


@@ -0,0 +1,293 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
    AutoencoderKLLTX2Audio,
    AutoencoderKLLTX2Video,
    FlowMatchEulerDiscreteScheduler,
    LTX2Pipeline,
    LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.utils.import_utils import is_peft_available

from ..testing_utils import floats_tensor, require_peft_backend


if is_peft_available():
    from peft import LoraConfig

sys.path.append(".")

from .utils import PeftLoraLoaderMixinTests  # noqa: E402
@require_peft_backend
class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    pipeline_class = LTX2Pipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}

    transformer_kwargs = {
        "in_channels": 4,
        "out_channels": 4,
        "patch_size": 1,
        "patch_size_t": 1,
        "num_attention_heads": 2,
        "attention_head_dim": 8,
        "cross_attention_dim": 16,
        "audio_in_channels": 4,
        "audio_out_channels": 4,
        "audio_num_attention_heads": 2,
        "audio_attention_head_dim": 4,
        "audio_cross_attention_dim": 8,
        "num_layers": 1,
        "qk_norm": "rms_norm_across_heads",
        "caption_channels": 32,
        "rope_double_precision": False,
        "rope_type": "split",
    }
    transformer_cls = LTX2VideoTransformer3DModel
    vae_kwargs = {
        "in_channels": 3,
        "out_channels": 3,
        "latent_channels": 4,
        "block_out_channels": (8,),
        "decoder_block_out_channels": (8,),
        "layers_per_block": (1,),
        "decoder_layers_per_block": (1, 1),
        "spatio_temporal_scaling": (True,),
        "decoder_spatio_temporal_scaling": (True,),
        "decoder_inject_noise": (False, False),
        "downsample_type": ("spatial",),
        "upsample_residual": (False,),
        "upsample_factor": (1,),
        "timestep_conditioning": False,
        "patch_size": 1,
        "patch_size_t": 1,
        "encoder_causal": True,
        "decoder_causal": False,
    }
    vae_cls = AutoencoderKLLTX2Video
    audio_vae_kwargs = {
        "base_channels": 4,
        "output_channels": 2,
        "ch_mult": (1,),
        "num_res_blocks": 1,
        "attn_resolutions": None,
        "in_channels": 2,
        "resolution": 32,
        "latent_channels": 2,
        "norm_type": "pixel",
        "causality_axis": "height",
        "dropout": 0.0,
        "mid_block_add_attention": False,
        "sample_rate": 16000,
        "mel_hop_length": 160,
        "is_causal": True,
        "mel_bins": 8,
    }
    audio_vae_cls = AutoencoderKLLTX2Audio
    vocoder_kwargs = {
        "in_channels": 16,  # output_channels * mel_bins = 2 * 8
        "hidden_channels": 32,
        "out_channels": 2,
        "upsample_kernel_sizes": [4, 4],
        "upsample_factors": [2, 2],
        "resnet_kernel_sizes": [3],
        "resnet_dilations": [[1, 3, 5]],
        "leaky_relu_negative_slope": 0.1,
        "output_sampling_rate": 16000,
    }
    vocoder_cls = LTX2Vocoder
    connectors_kwargs = {
        "caption_channels": 32,  # Will be set dynamically from text_encoder
        "text_proj_in_factor": 2,  # Will be set dynamically from text_encoder
        "video_connector_num_attention_heads": 4,
        "video_connector_attention_head_dim": 8,
        "video_connector_num_layers": 1,
        "video_connector_num_learnable_registers": None,
        "audio_connector_num_attention_heads": 4,
        "audio_connector_attention_head_dim": 8,
        "audio_connector_num_layers": 1,
        "audio_connector_num_learnable_registers": None,
        "connector_rope_base_seq_len": 32,
        "rope_theta": 10000.0,
        "rope_double_precision": False,
        "causal_temporal_positioning": False,
        "rope_type": "split",
    }
    connectors_cls = LTX2TextConnectors
    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-gemma3"
    text_encoder_cls, text_encoder_id = (
        Gemma3ForConditionalGeneration,
        "hf-internal-testing/tiny-gemma3",
    )
    denoiser_target_modules = ["to_q", "to_k", "to_out.0"]

    @property
    def output_shape(self):
        return (1, 5, 32, 32, 3)
    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 16
        num_channels = 4
        num_frames = 5
        num_latent_frames = 2
        latent_height = 8
        latent_width = 8
        generator = torch.manual_seed(0)

        noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width))
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "a robot dancing",
            "num_frames": num_frames,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "height": 32,
            "width": 32,
            "frame_rate": 25.0,
            "max_sequence_length": sequence_length,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs
    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
        # Override to instantiate LTX2-specific components (connectors, audio_vae, vocoder)
        torch.manual_seed(0)
        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        # Update caption_channels and text_proj_in_factor based on text_encoder config
        transformer_kwargs = self.transformer_kwargs.copy()
        transformer_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size
        connectors_kwargs = self.connectors_kwargs.copy()
        connectors_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size
        connectors_kwargs["text_proj_in_factor"] = text_encoder.config.text_config.num_hidden_layers + 1

        torch.manual_seed(0)
        transformer = self.transformer_cls(**transformer_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        audio_vae = self.audio_vae_cls(**self.audio_vae_kwargs)

        torch.manual_seed(0)
        vocoder = self.vocoder_cls(**self.vocoder_kwargs)

        torch.manual_seed(0)
        connectors = self.connectors_cls(**connectors_kwargs)

        if scheduler_cls is None:
            scheduler_cls = self.scheduler_cls
        scheduler = scheduler_cls(**self.scheduler_kwargs)

        rank = 4
        lora_alpha = rank if lora_alpha is None else lora_alpha

        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )
        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "transformer": transformer,
            "vae": vae,
            "audio_vae": audio_vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "connectors": connectors,
            "vocoder": vocoder,
        }

        return pipeline_components, text_lora_config, denoiser_lora_config
    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    @unittest.skip("Not supported in LTX2.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in LTX2.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in LTX2.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_save_pretrained_with_text_lora(self):
        pass