diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index c89204dff7..dc98af07b7 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -731,23 +731,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     " the logger on the traceback to understand the reason why the quantized model is not serializable."
                 )

-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
-        weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
-            ".safetensors", "{suffix}.safetensors"
-        )
-
         os.makedirs(save_directory, exist_ok=True)

-        if use_flashpack:
-            if not is_main_process:
-                return
-            from ..utils.flashpack_utils import save_flashpack
-            save_flashpack(
-                self,
-                save_directory,
-                variant=variant,
-            )
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
             private = kwargs.pop("private", None)
@@ -759,67 +744,80 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Only save the model itself if we are using distributed training
         model_to_save = self

-        # Attach architecture to the config
         # Save the config
         if is_main_process:
             model_to_save.save_config(save_directory)

-        # Save the model
-        state_dict = model_to_save.state_dict()
+        if use_flashpack:
+            if not is_main_process:
+                return

-        # Save the model
-        state_dict_split = split_torch_state_dict_into_shards(
-            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
-        )
+            from ..utils.flashpack_utils import save_flashpack

-        # Clean the folder from a previous save
-        if is_main_process:
-            for filename in os.listdir(save_directory):
-                if filename in state_dict_split.filename_to_tensors.keys():
-                    continue
-                full_filename = os.path.join(save_directory, filename)
-                if not os.path.isfile(full_filename):
-                    continue
-                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
-                weights_without_ext = weights_without_ext.replace("{suffix}", "")
-                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
-                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
-                if (
-                    filename.startswith(weights_without_ext)
-                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
-                ):
-                    os.remove(full_filename)
-
-        for filename, tensors in state_dict_split.filename_to_tensors.items():
-            shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
-            filepath = os.path.join(save_directory, filename)
-            if safe_serialization:
-                # At some point we will need to deal better with save_function (used for TPU and other distributed
-                # joyfulness), but for now this enough.
-                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
-            else:
-                torch.save(shard, filepath)
-
-        if state_dict_split.is_sharded:
-            index = {
-                "metadata": state_dict_split.metadata,
-                "weight_map": state_dict_split.tensor_to_filename,
-            }
-            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
-            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
-            # Save the index as well
-            with open(save_index_file, "w", encoding="utf-8") as f:
-                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-                f.write(content)
-            logger.info(
-                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
-            )
+            save_flashpack(model_to_save, save_directory, variant=variant)
         else:
-            path_to_weights = os.path.join(save_directory, weights_name)
-            logger.info(f"Model weights saved in {path_to_weights}")
+            weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+            weights_name = _add_variant(weights_name, variant)
+            weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+                ".safetensors", "{suffix}.safetensors"
+            )

+            state_dict = model_to_save.state_dict()
+            state_dict_split = split_torch_state_dict_into_shards(
+                state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+            )
+
+            # Clean the folder from a previous save
+            if is_main_process:
+                for filename in os.listdir(save_directory):
+                    if filename in state_dict_split.filename_to_tensors.keys():
+                        continue
+                    full_filename = os.path.join(save_directory, filename)
+                    if not os.path.isfile(full_filename):
+                        continue
+                    weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                    weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                    filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                    # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+                    if (
+                        filename.startswith(weights_without_ext)
+                        and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                    ):
+                        os.remove(full_filename)
+
+            # Save each shard
+            for filename, tensors in state_dict_split.filename_to_tensors.items():
+                shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
+                filepath = os.path.join(save_directory, filename)
+                if safe_serialization:
+                    # At some point we will need to deal better with save_function (used for TPU and other distributed
+                    # joyfulness), but for now this is enough.
+                    safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+                else:
+                    torch.save(shard, filepath)
+
+            # Save index file if sharded
+            if state_dict_split.is_sharded:
+                index = {
+                    "metadata": state_dict_split.metadata,
+                    "weight_map": state_dict_split.tensor_to_filename,
+                }
+                save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+                save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+                # Save the index as well
+                with open(save_index_file, "w", encoding="utf-8") as f:
+                    content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                    f.write(content)
+                logger.info(
+                    f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                    f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
+                    f"index located at {save_index_file}."
+                )
+            else:
+                path_to_weights = os.path.join(save_directory, weights_name)
+                logger.info(f"Model weights saved in {path_to_weights}")
+
+        # Push to hub if requested (common to both paths)
         if push_to_hub:
            # Create a new empty model card and eventually tag it
            model_card = load_or_create_model_card(repo_id, token=token)
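Usage sketch (not part of the patch): the hunks above move the FlashPack branch after the config save and nest the existing safetensors/sharding logic under its `else`, so `save_config` runs once for either path and `push_to_hub` is shared by both. Assuming a diffusers build that exposes the `use_flashpack` flag shown in this diff (and with the FlashPack dependency installed for the second call), any `ModelMixin` subclass would exercise the two paths roughly like this:

    from diffusers import UNet2DModel

    model = UNet2DModel()  # any ModelMixin subclass goes through the same save_pretrained path

    # Default path: config.json plus (sharded) safetensors, with an index file
    # when the checkpoint is split across multiple shards.
    model.save_pretrained("unet-safetensors", safe_serialization=True, max_shard_size="10GB")

    # FlashPack path: the main process still writes the config, then the weights are
    # handed to save_flashpack() instead of the sharding/safetensors machinery.
    # `use_flashpack` is the flag introduced by this patch (assumed available on this branch).
    model.save_pretrained("unet-flashpack", use_flashpack=True)

The directory names above are arbitrary; `safe_serialization` and `max_shard_size` are shown with their usual defaults only to make the contrast between the two branches explicit.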