From ff26d9ffd5bf21688361d00296dff13a0e4734aa Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 Jan 2026 17:12:43 +0530 Subject: [PATCH] up --- src/diffusers/models/modeling_utils.py | 132 +++++++++++++------------ 1 file changed, 67 insertions(+), 65 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index dc98af07b7..b29f160657 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1216,72 +1216,50 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file) - else: - flashpack_file = None - if use_flashpack: + flashpack_file = None + if use_flashpack: + try: + flashpack_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant("model.flashpack", variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) + except EnvironmentError: + flashpack_file = None + logger.warning( + "`use_flashpack` was specified to be True but not flashpack file was found. Resorting to non-flashpack alternatives." + ) + + if flashpack_file is None: + # in the case it is sharded, we have already the index + if is_sharded: + resolved_model_file, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + dduf_entries=dduf_entries, + ) + elif use_safetensors: + logger.warning("Trying to load model weights with safetensors format.") try: - flashpack_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant("model.flashpack", variant), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - dduf_entries=dduf_entries, - ) - except EnvironmentError: - flashpack_file = None - - if flashpack_file is None: - # in the case it is sharded, we have already the index - if is_sharded: - resolved_model_file, sharded_metadata = _get_checkpoint_shard_files( - pretrained_model_name_or_path, - index_file, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder or "", - dduf_entries=dduf_entries, - ) - elif use_safetensors: - logger.warning("Trying to load model weights with safetensors format.") - try: - resolved_model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - dduf_entries=dduf_entries, - ) - - except IOError as e: - logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") - if not allow_pickle: - raise - logger.warning( - "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." - ) - - if resolved_model_file is None and not is_sharded: resolved_model_file = _get_model_file( pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -1294,8 +1272,32 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): dduf_entries=dduf_entries, ) - if not isinstance(resolved_model_file, list): - resolved_model_file = [resolved_model_file] + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if resolved_model_file is None and not is_sharded: + resolved_model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) + + if not isinstance(resolved_model_file, list): + resolved_model_file = [resolved_model_file] # set dtype to instantiate the model under: # 1. If torch_dtype is not None, we use that dtype