1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2026-01-22 17:02:52 +05:30
parent 55eaa6efb2
commit 668f265054

View File

@@ -731,23 +731,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
os.makedirs(save_directory, exist_ok=True)
if use_flashpack:
if not is_main_process:
return
from ..utils.flashpack_utils import save_flashpack
save_flashpack(
self,
save_directory,
variant=variant,
)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", None)
@@ -759,67 +744,80 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Only save the model itself if we are using distributed training
model_to_save = self
# Attach architecture to the config
# Save the config
if is_main_process:
model_to_save.save_config(save_directory)
# Save the model
state_dict = model_to_save.state_dict()
if use_flashpack:
if not is_main_process:
return
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
from ..utils.flashpack_utils import save_flashpack
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
save_flashpack(model_to_save, save_directory, variant=variant)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict = model_to_save.state_dict()
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
# Save each shard
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
# Save index file if sharded
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
# Push to hub if requested (common to both paths)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)