mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add new LTX 2.0 spatial latent upsampler logic
This commit is contained in:
@@ -11,9 +11,8 @@ from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
@@ -580,11 +579,12 @@ def get_ltx2_spatial_latent_upsampler_config(version: str):
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 512,
|
||||
"mid_channels": 1024,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
@@ -595,7 +595,7 @@ def convert_ltx2_spatial_latent_upsampler(
|
||||
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
|
||||
):
|
||||
with init_empty_weights():
|
||||
latent_upsampler = LTXLatentUpsamplerModel(**config)
|
||||
latent_upsampler = LTX2LatentUpsamplerModel(**config)
|
||||
|
||||
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
latent_upsampler.to(dtype)
|
||||
@@ -708,7 +708,10 @@ def get_args():
|
||||
help="HF Hub id for the LTX 2.0 text tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--latent_upsampler_filename", default=None, type=str, help="Latent upsampler filename"
|
||||
"--latent_upsampler_filename",
|
||||
default="rc1/ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors",
|
||||
type=str,
|
||||
help="Latent upsampler filename",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
@@ -818,7 +821,9 @@ def main(args):
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.latent_upsampler:
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(filename=args.latent_upsampler_filename)
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
|
||||
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
|
||||
)
|
||||
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
|
||||
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
|
||||
original_latent_upsampler_ckpt, latent_upsampler_config, dtype=vae_dtype,
|
||||
|
||||
Reference in New Issue
Block a user