1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add AuraFlow GGUF support (#10463)

* Add support for loading AuraFlow models from GGUF

https://huggingface.co/city96/AuraFlow-v0.3-gguf

* Update AuraFlow documentation for GGUF, add GGUF tests and model detection.

* Address code review comments.

* Remove unused config.

---------

Co-authored-by: hlky <hlky@hlky.ac>
This commit is contained in:
AstraliteHeart
2025-01-07 23:53:12 -08:00
committed by GitHub
parent 80fd9260bb
commit cb342b745a
6 changed files with 218 additions and 3 deletions

View File

@@ -62,6 +62,33 @@ image = pipeline(prompt).images[0]
image.save("auraflow.png")
```
Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) are also supported:
```py
import torch
from diffusers import (
AuraFlowPipeline,
GGUFQuantizationConfig,
AuraFlowTransformer2DModel,
)
transformer = AuraFlowTransformer2DModel.from_single_file(
"https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
"fal/AuraFlow-v0.3",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
prompt = "a cute pony in a field of flowers"
image = pipeline(prompt).images[0]
image.save("auraflow.png")
```
## AuraFlowPipeline
[[autodoc]] AuraFlowPipeline

View File

@@ -25,6 +25,7 @@ from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
@@ -106,6 +107,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AuraFlowTransformer2DModel": {
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}

View File

@@ -94,6 +94,12 @@ CHECKPOINT_KEY_NAMES = {
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
"auraflow": [
"double_layers.0.attn.w2q.weight",
"double_layers.0.attn.w1q.weight",
"cond_seq_linear.weight",
"t_embedder.mlp.0.weight",
],
"flux": [
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -154,6 +160,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -635,6 +642,9 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
model_type = "auraflow"
elif (
CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
@@ -2090,6 +2100,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2689,3 +2700,95 @@ def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, checkpoint)
return checkpoint
def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
state_dict_keys = list(checkpoint.keys())
# Handle register tokens and positional embeddings
converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
# Handle time step projection
converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
# Handle context embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
# Calculate the number of layers
def calculate_layers(keys, key_prefix):
layers = set()
for k in keys:
if key_prefix in k:
layer_num = int(k.split(".")[1]) # get the layer number
layers.add(layer_num)
return len(layers)
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
# MMDiT blocks
for i in range(mmdit_layers):
# Feed-forward
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for orig_k, diffuser_k in path_mapping.items():
for k, v in weight_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
f"double_layers.{i}.{orig_k}.{k}.weight", None
)
# Norms
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
for orig_k, diffuser_k in path_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
f"double_layers.{i}.{orig_k}.1.weight", None
)
# Attentions
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
for k, v in attn_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
f"double_layers.{i}.attn.{k}.weight", None
)
# Single-DiT blocks
for i in range(single_dit_layers):
# Feed-forward
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for k, v in mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
f"single_layers.{i}.mlp.{k}.weight", None
)
# Norms
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
f"single_layers.{i}.modCX.1.weight", None
)
# Attentions
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
for k, v in x_attn_mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
f"single_layers.{i}.attn.{k}.weight", None
)
# Final blocks
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
# Handle the final norm layer
norm_weight = checkpoint.pop("modF.1.weight", None)
if norm_weight is not None:
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
else:
converted_state_dict["norm_out.linear.weight"] = None
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
return converted_state_dict

View File

@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
@@ -253,7 +254,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).

View File

@@ -450,7 +450,7 @@ class GGUFLinear(nn.Linear):
def forward(self, inputs):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output

View File

@@ -6,6 +6,8 @@ import torch
import torch.nn as nn
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
@@ -54,7 +56,8 @@ class GGUFSingleFileTesterMixin:
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
assert module.weight.dtype == torch.uint8
assert module.bias.dtype == torch.float32
if module.bias is not None:
assert module.bias.dtype == torch.float32
def test_gguf_memory_usage(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
@@ -377,3 +380,79 @@ class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf"
torch_dtype = torch.bfloat16
model_cls = AuraFlowTransformer2DModel
expected_memory_use_in_gb = 4
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
(1, 512, 2048),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
def test_pipeline_inference(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
transformer = self.model_cls.from_single_file(
self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
)
pipe = AuraFlowPipeline.from_pretrained(
"fal/AuraFlow-v0.3", transformer=transformer, torch_dtype=self.torch_dtype
)
pipe.enable_model_cpu_offload()
prompt = "a pony holding a sign that says hello"
output = pipe(
prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
).images[0]
output_slice = output[:3, :3, :].flatten()
expected_slice = np.array(
[
0.46484375,
0.546875,
0.64453125,
0.48242188,
0.53515625,
0.59765625,
0.47070312,
0.5078125,
0.5703125,
0.42773438,
0.50390625,
0.5703125,
0.47070312,
0.515625,
0.57421875,
0.45898438,
0.48632812,
0.53515625,
0.4453125,
0.5078125,
0.56640625,
0.47851562,
0.5234375,
0.57421875,
0.48632812,
0.5234375,
0.56640625,
]
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4