
LTX 2.0 scheduler and full pipeline conversion

Author: Daniel Gu
Date:   2025-12-23 07:41:28 +01:00
Parent: cbb10b8dca
Commit: 595f485ad8
2 changed files with 32 additions and 7 deletions

File 1 of 2: LTX 2 → diffusers conversion script

@@ -7,9 +7,15 @@ import safetensors.torch
 import torch
 from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download
-from transformers import AutoModel, AutoProcessor
+from transformers import AutoModel, AutoTokenizer
 
-from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
+from diffusers import (
+    AutoencoderKLLTX2Audio,
+    AutoencoderKLLTX2Video,
+    FlowMatchEulerDiscreteScheduler,
+    LTX2Pipeline,
+    LTX2VideoTransformer3DModel,
+)
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder
 from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
@@ -721,12 +727,31 @@ def main(args):
     if not args.full_pipeline:
         text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
 
-    tokenizer = AutoProcessor.from_pretrained(args.tokenizer_id)
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
     if not args.full_pipeline:
         tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
 
     if args.full_pipeline:
-        pass
+        scheduler = FlowMatchEulerDiscreteScheduler(
+            use_dynamic_shifting=True,
+            base_shift=0.95,
+            max_shift=2.05,
+            base_image_seq_len=1024,
+            max_image_seq_len=4096,
+            shift_terminal=0.1,
+        )
+
+        pipe = LTX2Pipeline(
+            scheduler=scheduler,
+            vae=vae,
+            audio_vae=audio_vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            vocoder=vocoder,
+        )
+
+        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
 
 if __name__ == '__main__':
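
With the scheduler added, --full_pipeline now writes a complete, loadable checkpoint. A minimal sketch of reloading it, assuming the standard DiffusionPipeline loading API; the prompt, step count, and the output's frames field are assumptions modeled on existing diffusers video pipelines, not shown in this diff:

import torch
from diffusers import LTX2Pipeline

# Point at the directory the conversion script wrote via --output_path.
pipe = LTX2Pipeline.from_pretrained("path/to/converted-ltx2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Illustrative call only: the argument names and the `frames` attribute are
# assumptions borrowed from other diffusers video pipelines.
output = pipe(prompt="a red panda walking through snow", num_inference_steps=40)
video = output.frames[0]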

File 2 of 2: the LTX2Pipeline in diffusers.pipelines.ltx2

@@ -883,10 +883,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         mu = calculate_shift(
             video_sequence_length,
-            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("base_image_seq_len", 1024),
             self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
+            self.scheduler.config.get("base_shift", 0.95),
+            self.scheduler.config.get("max_shift", 2.05),
         )
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
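
For context on these new defaults: calculate_shift resolves mu by linearly interpolating between base_shift and max_shift as the token sequence length runs from base_image_seq_len to max_image_seq_len, and mu is then handed to the scheduler through retrieve_timesteps. A minimal sketch of that interpolation under the updated LTX 2 defaults (calculate_shift_sketch is a stand-in name; the helper's exact signature is not shown in this diff):

def calculate_shift_sketch(
    seq_len: int,
    base_seq_len: int = 1024,
    max_seq_len: int = 4096,
    base_shift: float = 0.95,
    max_shift: float = 2.05,
) -> float:
    # Straight line through (base_seq_len, base_shift) and (max_seq_len, max_shift).
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return seq_len * slope + intercept

# With the updated defaults, a 1024-token video gets the base shift of 0.95
# and a 4096-token video gets the full 2.05.
assert abs(calculate_shift_sketch(1024) - 0.95) < 1e-9
assert abs(calculate_shift_sketch(4096) - 2.05) < 1e-9

The scheduler applies this mu only when use_dynamic_shifting is enabled, which the conversion script now sets; in diffusers' implementation, shift_terminal additionally stretches the shifted sigmas so the schedule terminates at that value (0.1 here).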