mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -1216,72 +1216,50 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
|
||||
|
||||
else:
|
||||
flashpack_file = None
|
||||
if use_flashpack:
|
||||
flashpack_file = None
|
||||
if use_flashpack:
|
||||
try:
|
||||
flashpack_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant("model.flashpack", variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
except EnvironmentError:
|
||||
flashpack_file = None
|
||||
logger.warning(
|
||||
"`use_flashpack` was specified to be True but not flashpack file was found. Resorting to non-flashpack alternatives."
|
||||
)
|
||||
|
||||
if flashpack_file is None:
|
||||
# in the case it is sharded, we have already the index
|
||||
if is_sharded:
|
||||
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
|
||||
pretrained_model_name_or_path,
|
||||
index_file,
|
||||
cache_dir=cache_dir,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder or "",
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
elif use_safetensors:
|
||||
logger.warning("Trying to load model weights with safetensors format.")
|
||||
try:
|
||||
flashpack_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant("model.flashpack", variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
except EnvironmentError:
|
||||
flashpack_file = None
|
||||
|
||||
if flashpack_file is None:
|
||||
# in the case it is sharded, we have already the index
|
||||
if is_sharded:
|
||||
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
|
||||
pretrained_model_name_or_path,
|
||||
index_file,
|
||||
cache_dir=cache_dir,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder or "",
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
elif use_safetensors:
|
||||
logger.warning("Trying to load model weights with safetensors format.")
|
||||
try:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
|
||||
except IOError as e:
|
||||
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
||||
if not allow_pickle:
|
||||
raise
|
||||
logger.warning(
|
||||
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
||||
)
|
||||
|
||||
if resolved_model_file is None and not is_sharded:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
||||
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
@@ -1294,8 +1272,32 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
|
||||
if not isinstance(resolved_model_file, list):
|
||||
resolved_model_file = [resolved_model_file]
|
||||
except IOError as e:
|
||||
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
||||
if not allow_pickle:
|
||||
raise
|
||||
logger.warning(
|
||||
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
||||
)
|
||||
|
||||
if resolved_model_file is None and not is_sharded:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
|
||||
if not isinstance(resolved_model_file, list):
|
||||
resolved_model_file = [resolved_model_file]
|
||||
|
||||
# set dtype to instantiate the model under:
|
||||
# 1. If torch_dtype is not None, we use that dtype
|
||||
|
||||
Reference in New Issue
Block a user