mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Single File] Add GGUF support for LTX (#10298)

* update

* add docs.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Dhruv Nair
Committed by: GitHub
Date: 2024-12-20 11:52:29 +05:30
Parent: 151b74cd77
Commit: dbc1d505f0
2 changed files with 46 additions and 8 deletions

View File

@@ -61,6 +61,45 @@ pipe = LTXImageToVideoPipeline.from_single_file(
)
```
Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) is also supported:
```py
import torch
from diffusers.utils import export_to_video
from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig

ckpt_path = (
    "https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
)

# Load the GGUF-quantized transformer directly from the single-file checkpoint
transformer = LTXVideoTransformer3DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Plug the quantized transformer into the regular LTX pipeline
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
    generator=torch.manual_seed(0),
).frames[0]
export_to_video(video, "output_gguf_ltx.mp4", fps=24)
```
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
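As a quick illustration of what the linked section covers, the snippet below is a minimal sketch (not an exhaustive list, and it assumes the LTX VAE exposes the usual `enable_tiling`/`enable_slicing` helpers) of common diffusers memory-saving toggles layered on top of the pipeline created above:

```py
# Minimal sketch of optional memory-saving toggles for the pipeline built above.
# Savings depend on GPU, resolution, and frame count; enable only what you need.

# Keep only the component that is currently running on the GPU.
pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()  # slower, but more aggressive offloading

# Decode in tiles/slices to lower peak VRAM during VAE decoding
# (assumes the LTX VAE exposes these standard helpers).
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
```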
## LTXPipeline

View File

@@ -99,10 +99,11 @@ CHECKPOINT_KEY_NAMES = {
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
"ltx-video": [
(
"model.diffusion_model.patchify_proj.weight",
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
),
"model.diffusion_model.patchify_proj.weight",
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
"patchify_proj.weight",
"transformer_blocks.27.scale_shift_table",
"vae.per_channel_statistics.mean-of-means",
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
@@ -601,7 +602,7 @@ def infer_diffusers_model_type(checkpoint):
        else:
            model_type = "flux-schnell"
-    elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]):
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
        model_type = "ltx-video"
    elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
@@ -2266,9 +2267,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
-    converted_state_dict = {
-        key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key
-    }
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}

    TRANSFORMER_KEYS_RENAME_DICT = {
        "model.diffusion_model.": "",