mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Single File] Add GGUF support for LTX (#10298)
* update * add docs. --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -61,6 +61,45 @@ pipe = LTXImageToVideoPipeline.from_single_file(
|
||||
)
|
||||
```
|
||||
|
||||
Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) are also supported:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
|
||||
)
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = LTXPipeline.from_pretrained(
|
||||
"Lightricks/LTX-Video",
|
||||
transformer=transformer,
|
||||
generator=torch.manual_seed(0),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=704,
|
||||
height=480,
|
||||
num_frames=161,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output_gguf_ltx.mp4", fps=24)
|
||||
```
|
||||
|
||||
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
|
||||
|
||||
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
@@ -99,10 +99,11 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
],
|
||||
"ltx-video": [
|
||||
(
|
||||
"model.diffusion_model.patchify_proj.weight",
|
||||
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
|
||||
),
|
||||
"model.diffusion_model.patchify_proj.weight",
|
||||
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
|
||||
"patchify_proj.weight",
|
||||
"transformer_blocks.27.scale_shift_table",
|
||||
"vae.per_channel_statistics.mean-of-means",
|
||||
],
|
||||
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
|
||||
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
|
||||
@@ -601,7 +602,7 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
model_type = "ltx-video"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
|
||||
@@ -2266,9 +2267,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
|
||||
|
||||
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {
|
||||
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key
|
||||
}
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"model.diffusion_model.": "",
|
||||
|
||||
Reference in New Issue
Block a user