mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
LTX 2.0 scheduler and full pipeline conversion
This commit is contained in:
@@ -7,9 +7,15 @@ import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
@@ -721,12 +727,31 @@ def main(args):
|
||||
if not args.full_pipeline:
|
||||
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
|
||||
|
||||
tokenizer = AutoProcessor.from_pretrained(args.tokenizer_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.full_pipeline:
|
||||
pass
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -883,10 +883,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("base_image_seq_len", 1024),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
self.scheduler.config.get("base_shift", 0.95),
|
||||
self.scheduler.config.get("max_shift", 2.05),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
|
||||
Reference in New Issue
Block a user