Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Single File] Add single file loading for SANA Transformer (#10947)
* added support for from_single_file
* added diffusers mapping script
* added testcase
* bug fix
* updated tests
* corrected code quality
* corrected code quality

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
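In practice, the feature added here lets the original Sana checkpoint be loaded directly into the Diffusers transformer class. A minimal usage sketch, assuming the checkpoint URL used by the new test at the bottom of this commit; the dtype choice is illustrative:

import torch
from diffusers import SanaTransformer2DModel

# Original (non-Diffusers) Sana checkpoint; URL taken from the test added in this commit.
ckpt_path = "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"

# from_single_file becomes available on SanaTransformer2DModel via FromOriginalModelMixin (added below).
transformer = SanaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)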
@@ -37,6 +37,7 @@ from .single_file_utils import (
     convert_ltx_vae_checkpoint_to_diffusers,
     convert_lumina2_to_diffusers,
     convert_mochi_transformer_checkpoint_to_diffusers,
+    convert_sana_transformer_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     convert_wan_transformer_to_diffusers,
@@ -119,6 +120,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "SanaTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "WanTransformer3DModel": {
         "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
         "default_subfolder": "transformer",
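For orientation, this registry entry is what from_single_file consults to find the conversion function for a given model class. The snippet below is an illustrative sketch only, not the library's internal code; the helper name and the bare dictionary lookup are assumptions:

# Hypothetical helper showing how an entry in SINGLE_FILE_LOADABLE_CLASSES is meant to be applied.
def map_checkpoint_for_class(class_name, checkpoint, registry):
    entry = registry[class_name]                 # e.g. registry["SanaTransformer2DModel"]
    mapping_fn = entry["checkpoint_mapping_fn"]  # convert_sana_transformer_to_diffusers for Sana
    return mapping_fn(checkpoint)                # returns a Diffusers-style state dict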
@@ -117,6 +117,12 @@ CHECKPOINT_KEY_NAMES = {
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
     "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+    "sana": [
+        "blocks.0.cross_attn.q_linear.weight",
+        "blocks.0.cross_attn.q_linear.bias",
+        "blocks.0.cross_attn.kv_linear.weight",
+        "blocks.0.cross_attn.kv_linear.bias",
+    ],
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
 }
@@ -178,6 +184,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
     "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
+    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
@@ -669,6 +676,9 @@ def infer_diffusers_model_type(checkpoint):
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
         model_type = "lumina2"
 
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
+        model_type = "sana"
+
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
         if "model.diffusion_model.patch_embedding.weight" in checkpoint:
             target_key = "model.diffusion_model.patch_embedding.weight"
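Taken together with the CHECKPOINT_KEY_NAMES and DIFFUSERS_DEFAULT_PIPELINE_PATHS entries above, this new branch lets a raw Sana state dict be fingerprinted and resolved to a default Diffusers repo. A small sketch, assuming both dicts are importable from diffusers.loaders.single_file_utils:

from diffusers.loaders.single_file_utils import (
    CHECKPOINT_KEY_NAMES,
    DIFFUSERS_DEFAULT_PIPELINE_PATHS,
)


def resolve_default_repo(checkpoint):
    # Mirrors the branch added above: any Sana fingerprint key marks the checkpoint as "sana".
    if any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
        return DIFFUSERS_DEFAULT_PIPELINE_PATHS["sana"]["pretrained_model_name_or_path"]
    return None


# A dummy state dict with one fingerprint key is enough to match.
print(resolve_default_repo({"blocks.0.cross_attn.q_linear.weight": None}))
# Expected: Efficient-Large-Model/Sana_1600M_1024px_diffusers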
@@ -2897,6 +2907,111 @@ def convert_lumina2_to_diffusers(checkpoint, **kwargs):
     return converted_state_dict
 
 
+def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1  # noqa: C401
+
+    # Positional and patch embeddings.
+    checkpoint.pop("pos_embed")
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Timestep embeddings.
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "t_embedder.mlp.0.weight"
+    )
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "t_embedder.mlp.2.weight"
+    )
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
+    converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
+
+    # Caption Projection.
+    checkpoint.pop("y_embedder.y_embedding")
+    converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
+    converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
+    converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
+    converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
+    converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
+
+    for i in range(num_layers):
+        converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
+            f"blocks.{i}.scale_shift_table"
+        )
+
+        # Self-Attention
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
+
+        # Output Projections
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
+            f"blocks.{i}.attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
+            f"blocks.{i}.attn.proj.bias"
+        )
+
+        # Cross-Attention
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.q_linear.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.q_linear.bias"
+        )
+
+        linear_sample_k, linear_sample_v = torch.chunk(
+            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
+        )
+        linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
+            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
+
+        # Output Projections
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.proj.bias"
+        )
+
+        # MLP
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.inverted_conv.conv.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
+            f"blocks.{i}.mlp.inverted_conv.conv.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.depth_conv.conv.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
+            f"blocks.{i}.mlp.depth_conv.conv.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.point_conv.conv.weight"
+        )
+
+    # Final layer
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
+
+    return converted_state_dict
+
+
 def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
 
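The only non-trivial part of the conversion above is undoing Sana's fused projections: self-attention stores a single qkv weight and cross-attention a fused kv weight, and torch.chunk along dim 0 recovers the separate to_q/to_k/to_v (or to_k/to_v) matrices that the Diffusers blocks expect. A self-contained sanity check with made-up sizes:

import torch

inner_dim = 8  # illustrative size, not the real Sana hidden size

# Fused self-attention projection as stored in the original checkpoint ...
qkv_weight = torch.randn(3 * inner_dim, inner_dim)
# ... split back into the separate to_q / to_k / to_v weights.
to_q, to_k, to_v = torch.chunk(qkv_weight, 3, dim=0)
assert to_q.shape == to_k.shape == to_v.shape == (inner_dim, inner_dim)

# Cross-attention fuses key and value the same way.
kv_weight = torch.randn(2 * inner_dim, inner_dim)
to_k_cross, to_v_cross = torch.chunk(kv_weight, 2, dim=0)
assert torch.equal(torch.cat([to_k_cross, to_v_cross], dim=0), kv_weight)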
@@ -18,7 +18,7 @@ import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention_processor import (
     Attention,
@@ -195,7 +195,7 @@ class SanaTransformerBlock(nn.Module):
         return hidden_states
 
 
-class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
 
tests/single_file/test_sana_transformer.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+    SanaTransformer2DModel,
+)
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    require_torch_accelerator,
+    torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@require_torch_accelerator
+class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
+    model_class = SanaTransformer2DModel
+    ckpt_path = (
+        "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+    )
+    alternate_keys_ckpt_paths = [
+        "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+    ]
+
+    repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_single_file_components(self):
+        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+        model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+        for param_name, param_value in model_single_file.config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+            assert (
+                model.config[param_name] == param_value
+            ), f"{param_name} differs between single file loading and pretrained loading"
+
+    def test_checkpoint_loading(self):
+        for ckpt_path in self.alternate_keys_ckpt_paths:
+            torch.cuda.empty_cache()
+            model = self.model_class.from_single_file(ckpt_path)
+
+            del model
+            gc.collect()
+            torch.cuda.empty_cache()