[Sana bug] bug fix for 2K model config (#10340)
* fix the Positional Embedding bug in the 2K model
* change the default model to the BF16 one for more stable training and output
* make style
* subtract buffer size
* add compute_module_persistent_sizes

Co-authored-by: yiyixuxu <yixu310@gmail.com>
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
 ```python
 from diffusers import SanaTransformer2DModel

-transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
+transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
 ```

 ## SanaTransformer2DModel
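For context, a minimal sketch of loading the recommended BF16 transformer into a full pipeline (checkpoint ids taken from the snippet above; everything else is illustrative and not part of this diff):

```python
import torch

from diffusers import SanaPipeline, SanaTransformer2DModel

# Load the transformer from the BF16 checkpoint, as the updated snippet recommends.
transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Reuse it when assembling the full pipeline from the same repository.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
```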
@@ -32,9 +32,9 @@ Available models:

 | Model | Recommended dtype |
 |:-----:|:-----------------:|
+| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
 | [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
 | [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
-| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
 | [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
 | [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
 | [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
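Should a script need to choose the dtype programmatically, the same recommendations can be expressed as a lookup (a sketch; ids and dtypes copied from the table above):

```python
import torch

# Recommended dtype per checkpoint, mirroring the table above.
RECOMMENDED_DTYPES = {
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers": torch.bfloat16,
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers": torch.float16,
    "Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers": torch.float16,
    "Efficient-Large-Model/Sana_1600M_512px_diffusers": torch.float16,
    "Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers": torch.float16,
    "Efficient-Large-Model/Sana_600M_1024px_diffusers": torch.float16,
}

repo_id = "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
dtype = RECOMMENDED_DTYPES[repo_id]  # torch.bfloat16
```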
@@ -88,13 +88,18 @@ def main(args):
     # y norm
     converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")

     # scheduler
     flow_shift = 3.0

     # model config
     if args.model_type == "SanaMS_1600M_P1_D20":
         layer_num = 20
     elif args.model_type == "SanaMS_600M_P1_D28":
         layer_num = 28
     else:
         raise ValueError(f"{args.model_type} is not supported.")
+    # Positional embedding interpolation scale.
+    interpolation_scale = {512: None, 1024: None, 2048: 1.0}

     for depth in range(layer_num):
         # Transformer blocks.
@@ -176,6 +181,7 @@ def main(args):
         patch_size=1,
         norm_elementwise_affine=False,
         norm_eps=1e-6,
+        interpolation_scale=interpolation_scale[args.image_size],
     )

     if is_accelerate_available():
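A sketch of how the new mapping flows through the conversion script: `interpolation_scale[args.image_size]` stays `None` for the 512px and 1024px configs, so the model keeps its built-in fallback, and becomes `1.0` for the 2K config. The `Args` class below is a hypothetical stand-in for the script's argparse namespace, purely for illustration:

```python
# Mapping added by this commit: image size -> positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}


class Args:  # hypothetical stand-in for the parsed command-line arguments
    image_size = 2048


args = Args()

# Value that ends up in the SanaTransformer2DModel config:
scale = interpolation_scale[args.image_size]
print(scale)  # 1.0 for the 2K model; None for 512/1024, which triggers the in-model fallback
```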
@@ -242,6 +242,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         patch_size: int = 1,
         norm_elementwise_affine: bool = False,
         norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
     ) -> None:
         super().__init__()

@@ -249,14 +250,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         inner_dim = num_attention_heads * attention_head_dim

         # 1. Patch Embedding
+        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
         self.patch_embed = PatchEmbed(
             height=sample_size,
             width=sample_size,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
-            interpolation_scale=None,
-            pos_embed_type=None,
+            interpolation_scale=interpolation_scale,
         )

         # 2. Additional condition embeddings
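This hunk carries the core of the config fix: instead of always handing `PatchEmbed` `interpolation_scale=None` and `pos_embed_type=None`, the transformer now accepts an `interpolation_scale` argument and otherwise derives one from `sample_size`. A minimal sketch of that fallback (the sample sizes below are illustrative, not taken from the diff):

```python
def resolve_interpolation_scale(interpolation_scale, sample_size):
    # Mirrors the expression added in SanaTransformer2DModel.__init__ above.
    return interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)


print(resolve_interpolation_scale(None, 32))   # 1   -> fallback when the config carries no scale
print(resolve_interpolation_scale(None, 128))  # 2   -> fallback grows with larger sample sizes
print(resolve_interpolation_scale(1.0, 64))    # 1.0 -> an explicit value (e.g. from the 2K config) wins
```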
@@ -59,13 +59,13 @@ EXAMPLE_DOC_STRING = """
         >>> from diffusers import SanaPAGPipeline

         >>> pipe = SanaPAGPipeline.from_pretrained(
-        ...     "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
+        ...     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
         ...     pag_applied_layers=["transformer_blocks.8"],
         ...     torch_dtype=torch.float32,
         ... )
         >>> pipe.to("cuda")
         >>> pipe.text_encoder.to(torch.bfloat16)
-        >>> pipe.transformer = pipe.transformer.to(torch.float16)
+        >>> pipe.transformer = pipe.transformer.to(torch.bfloat16)

         >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
         >>> image[0].save("output.png")
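Read as a standalone script, the updated PAG example follows a load-in-float32-then-cast pattern; the sketch below adds a `pag_scale` and a seeded generator for reproducibility (both illustrative additions, not part of the documented example):

```python
import torch

from diffusers import SanaPAGPipeline

pipe = SanaPAGPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    pag_applied_layers=["transformer_blocks.8"],
    torch_dtype=torch.float32,
)
pipe.to("cuda")
# Cast the heavy submodules to bfloat16, matching the BF16 checkpoint.
pipe.text_encoder.to(torch.bfloat16)
pipe.transformer = pipe.transformer.to(torch.bfloat16)

image = pipe(
    prompt='a cyberpunk cat with a neon sign that says "Sana"',
    pag_scale=2.0,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
image.save("output.png")
```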
@@ -62,11 +62,11 @@ EXAMPLE_DOC_STRING = """
         >>> from diffusers import SanaPipeline

         >>> pipe = SanaPipeline.from_pretrained(
-        ...     "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
+        ...     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
         ... )
         >>> pipe.to("cuda")
         >>> pipe.text_encoder.to(torch.bfloat16)
-        >>> pipe.transformer = pipe.transformer.to(torch.float16)
+        >>> pipe.transformer = pipe.transformer.to(torch.bfloat16)

         >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
         >>> image[0].save("output.png")
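The non-PAG pipeline uses the same pattern; below is a sketch extended with a few common generation arguments (`height`, `width`, `num_inference_steps`, `guidance_scale` and `generator` are standard pipeline options with illustrative values, not part of this diff):

```python
import torch

from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
)
pipe.to("cuda")
pipe.text_encoder.to(torch.bfloat16)
pipe.transformer = pipe.transformer.to(torch.bfloat16)

image = pipe(
    prompt='a cyberpunk cat with a neon sign that says "Sana"',
    height=1024,
    width=1024,
    num_inference_steps=20,
    guidance_scale=4.5,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
image.save("output.png")
```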
@@ -22,12 +22,14 @@ import traceback
 import unittest
 import unittest.mock as mock
 import uuid
-from typing import Dict, List, Tuple
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import requests_mock
 import torch
-from accelerate.utils import compute_module_sizes
+import torch.nn as nn
+from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
 from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     out_queue.join()


+def named_persistent_module_tensors(
+    module: nn.Module,
+    recurse: bool = False,
+):
+    """
+    A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module we want the tensors on.
+        recurse (`bool`, *optional*, defaults to `False`):
+            Whether or not to go look in every submodule or just return the direct parameters and buffers.
+    """
+    yield from module.named_parameters(recurse=recurse)
+
+    for named_buffer in module.named_buffers(recurse=recurse):
+        name, _ = named_buffer
+        # Get parent by splitting on dots and traversing the model
+        parent = module
+        if "." in name:
+            parent_name = name.rsplit(".", 1)[0]
+            for part in parent_name.split("."):
+                parent = getattr(parent, part)
+            name = name.split(".")[-1]
+        if name not in parent._non_persistent_buffers_set:
+            yield named_buffer
+
+
+def compute_module_persistent_sizes(
+    model: nn.Module,
+    dtype: Optional[Union[str, torch.device]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
+):
+    """
+    Compute the size of each submodule of a given model (parameters + persistent buffers).
+    """
+    if dtype is not None:
+        dtype = _get_proper_dtype(dtype)
+        dtype_size = dtype_byte_size(dtype)
+    if special_dtypes is not None:
+        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
+        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
+    module_sizes = defaultdict(int)
+
+    module_list = []
+
+    module_list = named_persistent_module_tensors(model, recurse=True)
+
+    for name, tensor in module_list:
+        if special_dtypes is not None and name in special_dtypes:
+            size = tensor.numel() * special_dtypes_size[name]
+        elif dtype is None:
+            size = tensor.numel() * dtype_byte_size(tensor.dtype)
+        elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            # According to the code in set_module_tensor_to_device, these types won't be converted
+            # so use their original size here
+            size = tensor.numel() * dtype_byte_size(tensor.dtype)
+        else:
+            size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
+        name_parts = name.split(".")
+        for idx in range(len(name_parts) + 1):
+            module_sizes[".".join(name_parts[:idx])] += size
+
+    return module_sizes
+
+
 class ModelUtilsTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
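A toy usage sketch of the helper added above (assuming `compute_module_persistent_sizes` from the hunk is in scope); it shows that non-persistent buffers are excluded from the reported sizes, which is the point of the new function. The module below is made up for illustration:

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("stats", torch.zeros(4))  # persistent buffer: counted
        self.register_buffer("cache", torch.zeros(1024), persistent=False)  # not counted


sizes = compute_module_persistent_sizes(Toy())
print(sizes[""])        # total bytes of parameters + persistent buffers only
print(sizes["linear"])  # bytes attributed to the `linear` submodule
```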
@@ -1012,7 +1080,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1042,7 +1110,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

@@ -1076,7 +1144,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)

@@ -1104,7 +1172,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1132,7 +1200,7 @@ class ModelTesterMixin:

         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
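A worked example of the shard-size arithmetic these tests rely on, now driven by the persistent size so non-persistent buffers no longer inflate the budget (the byte count is made up for illustration):

```python
# Suppose the test model's parameters + persistent buffers total roughly 200 KB.
model_size = 200_000  # bytes, as returned by compute_module_persistent_sizes(model)[""]

# The tests ask save_pretrained to shard at 75% of that size, expressed in KB.
max_shard_size = int((model_size * 0.75) / (2**10))
print(max_shard_size)  # 146

# Passing max_shard_size=f"{max_shard_size}KB" to save_pretrained therefore forces the
# weights into more than one shard file, which is exactly what the sharding tests exercise.
```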
@@ -1164,7 +1232,7 @@ class ModelTesterMixin:

         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         variant = "fp16"
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1204,7 +1272,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1233,7 +1301,7 @@ class ModelTesterMixin:
         config, _ = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         variant = "fp16"
         with tempfile.TemporaryDirectory() as tmp_dir: