Add AuraFlow GGUF support (#10463)
* Add support for loading AuraFlow models from GGUF (https://huggingface.co/city96/AuraFlow-v0.3-gguf).
* Update AuraFlow documentation for GGUF, add GGUF tests and model detection.
* Address code review comments.
* Remove unused config.

Co-authored-by: hlky <hlky@hlky.ac>
@@ -62,6 +62,33 @@ image = pipeline(prompt).images[0]
image.save("auraflow.png")
```

Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) is also supported:

```py
import torch
from diffusers import (
    AuraFlowPipeline,
    GGUFQuantizationConfig,
    AuraFlowTransformer2DModel,
)

transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

prompt = "a cute pony in a field of flowers"
image = pipeline(prompt).images[0]
image.save("auraflow.png")
```
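If the pipeline still exceeds available VRAM even with the quantized transformer, model CPU offload can be enabled before inference; the GGUF test added in this commit uses the same call:

```py
# Optional: keep submodules on CPU and move them to the GPU only while needed,
# trading some speed for lower peak memory.
pipeline.enable_model_cpu_offload()
```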
## AuraFlowPipeline
[[autodoc]] AuraFlowPipeline
@@ -25,6 +25,7 @@ from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
    SingleFileComponentError,
    convert_animatediff_checkpoint_to_diffusers,
    convert_auraflow_transformer_checkpoint_to_diffusers,
    convert_autoencoder_dc_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
    convert_flux_transformer_checkpoint_to_diffusers,
@@ -106,6 +107,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AuraFlowTransformer2DModel": {
        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
}
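For orientation, a rough sketch (not the actual diffusers code; `convert_for_class` and its arguments are hypothetical) of how a registry entry like the one added above can drive single-file loading: the class name selects a `checkpoint_mapping_fn`, which converts the original-layout state dict into the diffusers layout before the weights are loaded.

```py
# Hypothetical illustration of registry-driven conversion; only the registry shape
# and the mapping function name come from the diff above.
def convert_for_class(class_name, checkpoint, registry):
    entry = registry[class_name]                 # e.g. "AuraFlowTransformer2DModel"
    mapping_fn = entry["checkpoint_mapping_fn"]  # convert_auraflow_transformer_checkpoint_to_diffusers
    return mapping_fn(checkpoint)                # state dict in diffusers key layout
```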
@@ -94,6 +94,12 @@ CHECKPOINT_KEY_NAMES = {
    "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
    "animatediff_rgb": "controlnet_cond_embedding.weight",
    "auraflow": [
        "double_layers.0.attn.w2q.weight",
        "double_layers.0.attn.w1q.weight",
        "cond_seq_linear.weight",
        "t_embedder.mlp.0.weight",
    ],
    "flux": [
        "double_blocks.0.img_attn.norm.key_norm.scale",
        "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -154,6 +160,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
    "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
    "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
    "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
    "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -635,6 +642,9 @@ def infer_diffusers_model_type(checkpoint):
    elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
        model_type = "hunyuan-video"

    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
        model_type = "auraflow"

    elif (
        CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
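The detection branch above only declares a checkpoint to be AuraFlow when every key in the `auraflow` list is present, which avoids false positives on other DiT-style checkpoints that share a subset of names. A standalone sketch of the same idea (the helper name and `state_dict` argument are illustrative, not part of the diff):

```py
# Only the key list and the all(...) pattern come from the diff above.
AURAFLOW_KEYS = [
    "double_layers.0.attn.w2q.weight",
    "double_layers.0.attn.w1q.weight",
    "cond_seq_linear.weight",
    "t_embedder.mlp.0.weight",
]

def looks_like_auraflow(state_dict) -> bool:
    return all(key in state_dict for key in AURAFLOW_KEYS)
```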
@@ -2090,6 +2100,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    keys = list(checkpoint.keys())

    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2689,3 +2700,95 @@ def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
        handler_fn_inplace(key, checkpoint)

    return checkpoint


def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    state_dict_keys = list(checkpoint.keys())

    # Handle register tokens and positional embeddings
    converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)

    # Handle time step projection
    converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
    converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
    converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
    converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)

    # Handle context embedder
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)

    # Calculate the number of layers
    def calculate_layers(keys, key_prefix):
        layers = set()
        for k in keys:
            if key_prefix in k:
                layer_num = int(k.split(".")[1])  # get the layer number
                layers.add(layer_num)
        return len(layers)

    mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
    single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")

    # MMDiT blocks
    for i in range(mmdit_layers):
        # Feed-forward
        path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
        weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
        for orig_k, diffuser_k in path_mapping.items():
            for k, v in weight_mapping.items():
                converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
                    f"double_layers.{i}.{orig_k}.{k}.weight", None
                )

        # Norms
        path_mapping = {"modX": "norm1", "modC": "norm1_context"}
        for orig_k, diffuser_k in path_mapping.items():
            converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
                f"double_layers.{i}.{orig_k}.1.weight", None
            )

        # Attentions
        x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
        context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
        for attn_mapping in [x_attn_mapping, context_attn_mapping]:
            for k, v in attn_mapping.items():
                converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
                    f"double_layers.{i}.attn.{k}.weight", None
                )

    # Single-DiT blocks
    for i in range(single_dit_layers):
        # Feed-forward
        mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
        for k, v in mapping.items():
            converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
                f"single_layers.{i}.mlp.{k}.weight", None
            )

        # Norms
        converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
            f"single_layers.{i}.modCX.1.weight", None
        )

        # Attentions
        x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
        for k, v in x_attn_mapping.items():
            converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
                f"single_layers.{i}.attn.{k}.weight", None
            )

    # Final blocks
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)

    # Handle the final norm layer
    norm_weight = checkpoint.pop("modF.1.weight", None)
    if norm_weight is not None:
        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
    else:
        converted_state_dict["norm_out.linear.weight"] = None

    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")

    return converted_state_dict
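The final-norm handling above calls `swap_scale_shift`, which is defined elsewhere in `single_file_utils.py` and is not part of this diff. Roughly, it reorders an AdaLN modulation tensor from the original (shift, scale) ordering to the (scale, shift) ordering diffusers expects; a sketch of that behavior (the exact definition in the file may differ in detail):

```py
import torch

def swap_scale_shift(weight, dim=None):
    # Split the modulation parameters in half along dim 0 and swap the halves,
    # so (shift, scale) becomes (scale, shift).
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)
```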
@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
@@ -253,7 +254,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
        return encoder_hidden_states, hidden_states


-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -450,7 +450,7 @@ class GGUFLinear(nn.Linear):
    def forward(self, inputs):
        weight = dequantize_gguf_tensor(self.weight)
        weight = weight.to(self.compute_dtype)
-        bias = self.bias.to(self.compute_dtype)
+        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

        output = torch.nn.functional.linear(inputs, weight, bias)
        return output
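The one-line change above guards against layers created with `bias=False`, where `self.bias` is `None` and the old unconditional `.to()` call would raise. A minimal illustration of the pattern on a plain `nn.Linear` (not the GGUF-quantized layer itself):

```py
import torch
import torch.nn as nn

layer = nn.Linear(8, 8, bias=False)  # bias-less layer

# Mirrors the guarded conversion in GGUFLinear.forward: only cast the bias if it exists.
bias = layer.bias.to(torch.bfloat16) if layer.bias is not None else None
assert bias is None  # no crash, and F.linear accepts bias=None
```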
@@ -6,6 +6,8 @@ import torch
import torch.nn as nn

from diffusers import (
    AuraFlowPipeline,
    AuraFlowTransformer2DModel,
    FluxPipeline,
    FluxTransformer2DModel,
    GGUFQuantizationConfig,
@@ -54,7 +56,8 @@ class GGUFSingleFileTesterMixin:
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
                assert module.weight.dtype == torch.uint8
-                assert module.bias.dtype == torch.float32
+                if module.bias is not None:
+                    assert module.bias.dtype == torch.float32

    def test_gguf_memory_usage(self):
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
@@ -377,3 +380,79 @@ class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase
        )
        max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
        assert max_diff < 1e-4


class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    ckpt_path = "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf"
    torch_dtype = torch.bfloat16
    model_cls = AuraFlowTransformer2DModel
    expected_memory_use_in_gb = 4

    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 2048),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }

    def test_pipeline_inference(self):
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
        transformer = self.model_cls.from_single_file(
            self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
        )
        pipe = AuraFlowPipeline.from_pretrained(
            "fal/AuraFlow-v0.3", transformer=transformer, torch_dtype=self.torch_dtype
        )
        pipe.enable_model_cpu_offload()

        prompt = "a pony holding a sign that says hello"
        output = pipe(
            prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
        ).images[0]
        output_slice = output[:3, :3, :].flatten()
        expected_slice = np.array(
            [
                0.46484375,
                0.546875,
                0.64453125,
                0.48242188,
                0.53515625,
                0.59765625,
                0.47070312,
                0.5078125,
                0.5703125,
                0.42773438,
                0.50390625,
                0.5703125,
                0.47070312,
                0.515625,
                0.57421875,
                0.45898438,
                0.48632812,
                0.53515625,
                0.4453125,
                0.5078125,
                0.56640625,
                0.47851562,
                0.5234375,
                0.57421875,
                0.48632812,
                0.5234375,
                0.56640625,
            ]
        )
        max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
        assert max_diff < 1e-4