1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Don't use load_state_dict if torch is not installed.

This commit is contained in:
Pedro Cuenca
2022-09-30 14:58:08 +02:00
parent 877bec8a91
commit 2b24dba599

View File

@@ -28,7 +28,6 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .modeling_utils import load_state_dict
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
@@ -37,6 +36,7 @@ from .utils import (
WEIGHTS_NAME,
logging,
)
from . import is_torch_available
logger = logging.get_logger(__name__)
@@ -391,6 +391,14 @@ class FlaxModelMixin:
)
if from_pt:
if is_torch_available():
from .modeling_utils import load_state_dict
else:
raise EnvironmentError(
f"Can't load the model in PyTorch format because PyTorch is not installed. "
f"Please, install PyTorch or use native Flax weights."
)
# Step 1: Get the pytorch file
pytorch_model_file = load_state_dict(model_file)