1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[docs] diffusers gguf checkpoints (#12092)

* feat: support loading diffusers format gguf checkpoints.

* update

* update

* qwen

* up

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Sayak Paul
2025-08-09 08:49:49 +05:30
committed by GitHub
parent f20aba3e87
commit 03c3f69aa5

View File

@@ -77,3 +77,44 @@ Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels
- Q5_K
- Q6_K
## Convert to GGUF
Use the Space below to convert a Diffusers checkpoint into the GGUF format for inference.
run conversion:
<iframe
src="https://diffusers-internal-dev-diffusers-to-gguf.hf.space"
frameborder="0"
width="850"
height="450"
></iframe>
```py
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
ckpt_path = (
"https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
config="black-forest-labs/FLUX.1-dev",
subfolder="transformer",
torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
When using Diffusers format GGUF checkpoints, it's a must to provide the model `config` path. If the
model config resides in a `subfolder`, that needs to be specified, too.