1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add sana-sprint (#11074)

* add sana-sprint




---------

Co-authored-by: Junsong Chen <cjs1020440147@icloud.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
YiYi Xu
2025-03-21 06:21:18 -10:00
committed by GitHub
parent 844221ae4e
commit 8a63aa5e4f
15 changed files with 1996 additions and 124 deletions

View File

@@ -16,7 +16,9 @@ from diffusers import (
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaPipeline,
SanaSprintPipeline,
SanaTransformer2DModel,
SCMScheduler,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
@@ -25,6 +27,7 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
@@ -72,15 +75,42 @@ def main(args):
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# AdaLN-single LN
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Handle different time embedding structure based on model type
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
# For Sana Sprint, the time embedding structure is different
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Guidance embedder for Sana Sprint
converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
"cfg_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
"cfg_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
else:
# Original Sana time embedding structure
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
"t_embedder.mlp.0.bias"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
"t_embedder.mlp.2.bias"
)
# Shared norm.
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
@@ -96,14 +126,22 @@ def main(args):
flow_shift = 3.0
# model config
if args.model_type == "SanaMS_1600M_P1_D20":
if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
layer_num = 28
elif args.model_type == "SanaMS_4800M_P1_D60":
layer_num = 60
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
qk_norm = (
"rms_norm_across_heads"
if args.model_type
in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
else None
)
for depth in range(layer_num):
# Transformer blocks.
@@ -117,6 +155,14 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
if qk_norm is not None:
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.attn.k_norm.weight"
)
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
@@ -154,6 +200,14 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
if qk_norm is not None:
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.k_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
@@ -169,24 +223,37 @@ def main(args):
# Transformer
with CTX():
transformer = SanaTransformer2DModel(
in_channels=32,
out_channels=32,
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
num_layers=model_kwargs[args.model_type]["num_layers"],
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
caption_channels=2304,
mlp_ratio=2.5,
attention_bias=False,
sample_size=args.image_size // 32,
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
interpolation_scale=interpolation_scale[args.image_size],
)
transformer_kwargs = {
"in_channels": 32,
"out_channels": 32,
"num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
"attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
"num_layers": model_kwargs[args.model_type]["num_layers"],
"num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
"cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
"cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
"caption_channels": 2304,
"mlp_ratio": 2.5,
"attention_bias": False,
"sample_size": args.image_size // 32,
"patch_size": 1,
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"interpolation_scale": interpolation_scale[args.image_size],
}
# Add qk_norm parameter for Sana Sprint
if args.model_type in [
"SanaMS1.5_1600M_P1_D20",
"SanaMS1.5_4800M_P1_D60",
"SanaSprint_600M_P1_D28",
"SanaSprint_1600M_P1_D20",
]:
transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
transformer_kwargs["guidance_embeds"] = True
transformer = SanaTransformer2DModel(**transformer_kwargs)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_state_dict)
@@ -196,6 +263,8 @@ def main(args):
try:
state_dict.pop("y_embedder.y_embedding")
state_dict.pop("pos_embed")
state_dict.pop("logvar_linear.weight")
state_dict.pop("logvar_linear.bias")
except KeyError:
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
@@ -210,47 +279,75 @@ def main(args):
print(
colored(
f"Only saving transformer model of {args.model_type}. "
f"Set --save_full_pipeline to save the whole SanaPipeline",
f"Set --save_full_pipeline to save the whole Pipeline",
"green",
attrs=["bold"],
)
)
transformer.save_pretrained(
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
)
else:
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
# Text Encoder
text_encoder_model_path = "google/gemma-2-2b-it"
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
tokenizer.padding_side = "right"
text_encoder = AutoModelForCausalLM.from_pretrained(
text_encoder_model_path, torch_dtype=torch.bfloat16
).get_decoder()
# Scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
# Choose the appropriate pipeline and scheduler based on model type
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
# Force SCM Scheduler for Sana Sprint regardless of scheduler_type
if args.scheduler_type != "scm":
print(
colored(
f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
"yellow",
attrs=["bold"],
)
)
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
# SCM Scheduler for Sana Sprint
scheduler_config = {
"num_train_timesteps": 1000,
"prediction_type": "trigflow",
"sigma_data": 0.5,
}
scheduler = SCMScheduler(**scheduler_config)
pipe = SanaSprintPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
else:
# Original Sana scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
DTYPE_MAPPING = {
@@ -259,12 +356,6 @@ DTYPE_MAPPING = {
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -281,10 +372,23 @@ if __name__ == "__main__":
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
"--model_type",
default="SanaMS_1600M_P1_D20",
type=str,
choices=[
"SanaMS_1600M_P1_D20",
"SanaMS_600M_P1_D28",
"SanaMS_4800M_P1_D60",
"SanaSprint_1600M_P1_D20",
"SanaSprint_600M_P1_D28",
],
)
parser.add_argument(
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
"--scheduler_type",
default="flow-dpm_solver",
type=str,
choices=["flow-dpm_solver", "flow-euler", "scm"],
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
@@ -309,10 +413,41 @@ if __name__ == "__main__":
"cross_attention_dim": 1152,
"num_layers": 28,
},
"SanaMS1.5_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
"SanaMS1.5__4800M_P1_D60": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 60,
},
"SanaSprint_600M_P1_D28": {
"num_attention_heads": 36,
"attention_head_dim": 32,
"num_cross_attention_heads": 16,
"cross_attention_head_dim": 72,
"cross_attention_dim": 1152,
"num_layers": 28,
},
"SanaSprint_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
}
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
main(args)