1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Files
diffusers/tests/lora/test_lora_layers_ltx2.py
dg845 f1a93c765f Add Flag to PeftLoraLoaderMixinTests to Enable/Disable Text Encoder LoRA Tests (#12962)
* Improve incorrect LoRA format error message

* Add flag in PeftLoraLoaderMixinTests to disable text encoder LoRA tests

* Apply changes to LTX2LoraTests

* Further improve incorrect LoRA format error msg following review

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-12 16:01:58 -08:00

272 lines
9.0 KiB
Python

# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.utils.import_utils import is_peft_available
from ..testing_utils import floats_tensor, require_peft_backend
if is_peft_available():
from peft import LoraConfig
sys.path.append(".")
from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    """LoRA tests for `LTX2Pipeline`, driven by the shared `PeftLoraLoaderMixinTests` suite.

    The class attributes describe tiny dummy components for the mixin to build:
    - the transformer
    - the video VAE and the audio VAE
    - the vocoder
    - the text connectors
    - a tiny Gemma3 text encoder and tokenizer

    `supports_text_encoder_loras = False` turns off the mixin's text-encoder LoRA tests, so only the
    denoiser (transformer) receives LoRA adapters here.
    """

    pipeline_class = LTX2Pipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}

    transformer_kwargs = {
        "in_channels": 4,
        "out_channels": 4,
        "patch_size": 1,
        "patch_size_t": 1,
        "num_attention_heads": 2,
        "attention_head_dim": 8,
        "cross_attention_dim": 16,
        "audio_in_channels": 4,
        "audio_out_channels": 4,
        "audio_num_attention_heads": 2,
        "audio_attention_head_dim": 4,
        "audio_cross_attention_dim": 8,
        "num_layers": 1,
        "qk_norm": "rms_norm_across_heads",
        "caption_channels": 32,  # Overridden in get_dummy_components from the text encoder's hidden size
        "rope_double_precision": False,
        "rope_type": "split",
    }
    transformer_cls = LTX2VideoTransformer3DModel

    vae_kwargs = {
        "in_channels": 3,
        "out_channels": 3,
        "latent_channels": 4,
        "block_out_channels": (8,),
        "decoder_block_out_channels": (8,),
        "layers_per_block": (1,),
        "decoder_layers_per_block": (1, 1),
        "spatio_temporal_scaling": (True,),
        "decoder_spatio_temporal_scaling": (True,),
        "decoder_inject_noise": (False, False),
        "downsample_type": ("spatial",),
        "upsample_residual": (False,),
        "upsample_factor": (1,),
        "timestep_conditioning": False,
        "patch_size": 1,
        "patch_size_t": 1,
        "encoder_causal": True,
        "decoder_causal": False,
    }
    vae_cls = AutoencoderKLLTX2Video

    audio_vae_kwargs = {
        "base_channels": 4,
        "output_channels": 2,
        "ch_mult": (1,),
        "num_res_blocks": 1,
        "attn_resolutions": None,
        "in_channels": 2,
        "resolution": 32,
        "latent_channels": 2,
        "norm_type": "pixel",
        "causality_axis": "height",
        "dropout": 0.0,
        "mid_block_add_attention": False,
        "sample_rate": 16000,
        "mel_hop_length": 160,
        "is_causal": True,
        "mel_bins": 8,
    }
    audio_vae_cls = AutoencoderKLLTX2Audio

    vocoder_kwargs = {
        "in_channels": 16,  # output_channels * mel_bins = 2 * 8
        "hidden_channels": 32,
        "out_channels": 2,
        "upsample_kernel_sizes": [4, 4],
        "upsample_factors": [2, 2],
        "resnet_kernel_sizes": [3],
        "resnet_dilations": [[1, 3, 5]],
        "leaky_relu_negative_slope": 0.1,
        "output_sampling_rate": 16000,
    }
    vocoder_cls = LTX2Vocoder

    connectors_kwargs = {
        "caption_channels": 32,  # Overridden in get_dummy_components from the text encoder config
        "text_proj_in_factor": 2,  # Overridden in get_dummy_components from the text encoder config
        "video_connector_num_attention_heads": 4,
        "video_connector_attention_head_dim": 8,
        "video_connector_num_layers": 1,
        "video_connector_num_learnable_registers": None,
        "audio_connector_num_attention_heads": 4,
        "audio_connector_attention_head_dim": 8,
        "audio_connector_num_layers": 1,
        "audio_connector_num_learnable_registers": None,
        "connector_rope_base_seq_len": 32,
        "rope_theta": 10000.0,
        "rope_double_precision": False,
        "causal_temporal_positioning": False,
        "rope_type": "split",
    }
    connectors_cls = LTX2TextConnectors

    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-gemma3"
    text_encoder_cls, text_encoder_id = (
        Gemma3ForConditionalGeneration,
        "hf-internal-testing/tiny-gemma3",
    )

    # Single source of truth for which transformer sub-modules get LoRA adapters;
    # get_dummy_components builds the denoiser LoraConfig from this list.
    denoiser_target_modules = ["to_q", "to_k", "to_out.0"]
    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        """Expected pipeline output shape for the inputs from `get_dummy_inputs`.

        The values match `num_frames=5`, `height=width=32` and the VAE's 3 output channels.
        The axis order is assumed to be (batch, frames, height, width, channels); confirm if this changes.
        """
        return (1, 5, 32, 32, 3)

    def get_dummy_inputs(self, with_generator=True):
        """Build dummy inputs for the mixin tests.

        Args:
            with_generator: When True, the seeded generator is added to the pipeline kwargs
                under the "generator" key.

        Returns:
            A tuple `(noise, input_ids, pipeline_inputs)`:
            - `noise` and `input_ids` are random tensors returned to satisfy the mixin's interface.
            - `pipeline_inputs` is the kwargs dict passed to the pipeline call.
        """
        batch_size = 1
        sequence_length = 16
        num_channels = 4
        num_frames = 5
        num_latent_frames = 2
        latent_height = 8
        latent_width = 8

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width))
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "a robot dancing",
            "num_frames": num_frames,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "height": 32,
            "width": 32,
            "frame_rate": 25.0,
            "max_sequence_length": sequence_length,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
        """Instantiate the LTX2 pipeline components and the LoRA configs used by the mixin tests.

        This overrides the mixin because LTX2 needs extra components (connectors, audio VAE, vocoder)
        whose sizes depend on the loaded text encoder.

        Args:
            scheduler_cls: Scheduler class to use. Defaults to `self.scheduler_cls` when None.
            use_dora: Forwarded as `use_dora` to both LoRA configs.
            lora_alpha: LoRA alpha for both configs. Defaults to the rank (4) when None.

        Returns:
            A tuple `(pipeline_components, text_lora_config, denoiser_lora_config)`, where
            `pipeline_components` is the kwargs dict for `LTX2Pipeline`.
        """
        torch.manual_seed(0)
        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        # Size the caption/text projection inputs from the actual text encoder config.
        # The class-level dicts are copied so the shared class attributes are never mutated.
        text_config = text_encoder.config.text_config
        transformer_kwargs = self.transformer_kwargs.copy()
        transformer_kwargs["caption_channels"] = text_config.hidden_size

        connectors_kwargs = self.connectors_kwargs.copy()
        connectors_kwargs["caption_channels"] = text_config.hidden_size
        connectors_kwargs["text_proj_in_factor"] = text_config.num_hidden_layers + 1

        # Re-seed before each component so its initial weights do not depend on how many
        # random draws earlier components consumed.
        torch.manual_seed(0)
        transformer = self.transformer_cls(**transformer_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)
        # Disable framewise encode/decode on the dummy VAE (presumably so the tiny test
        # inputs go through the non-framewise path; confirm).
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        audio_vae = self.audio_vae_cls(**self.audio_vae_kwargs)

        torch.manual_seed(0)
        vocoder = self.vocoder_cls(**self.vocoder_kwargs)

        torch.manual_seed(0)
        connectors = self.connectors_cls(**connectors_kwargs)

        scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls
        scheduler = scheduler_cls(**self.scheduler_kwargs)

        rank = 4
        lora_alpha = rank if lora_alpha is None else lora_alpha

        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        # Use the class-declared target modules, as the mixin's own get_dummy_components does.
        # A hard-coded list here would let the adapted modules drift from `denoiser_target_modules`,
        # which other mixin logic may read.
        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=self.denoiser_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "transformer": transformer,
            "vae": vae,
            "audio_vae": audio_vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "connectors": connectors,
            "vocoder": vocoder,
        }

        return pipeline_components, text_lora_config, denoiser_lora_config

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        """Run the mixin test with a looser tolerance (atol=9e-3) for LTX2."""
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        """Run the mixin test with a looser tolerance (atol=9e-3) for LTX2."""
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    @unittest.skip("Not supported in LTX2.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in LTX2.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in LTX2.")
    def test_modify_padding_mode(self):
        pass