From 2b24dba599fa2e9f306e6fc77a67f1a4a02a88f7 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 30 Sep 2022 14:58:08 +0200 Subject: [PATCH] Don't use `load_state_dict` if torch is not installed. --- src/diffusers/modeling_flax_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 80c3fb68a4..1a3fea1e8a 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -28,7 +28,6 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R from requests import HTTPError from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax -from .modeling_utils import load_state_dict from .utils import ( CONFIG_NAME, DIFFUSERS_CACHE, @@ -37,6 +36,7 @@ from .utils import ( WEIGHTS_NAME, logging, ) +from . import is_torch_available logger = logging.get_logger(__name__) @@ -391,6 +391,14 @@ class FlaxModelMixin: ) if from_pt: + if is_torch_available(): + from .modeling_utils import load_state_dict + else: + raise EnvironmentError( + f"Can't load the model in PyTorch format because PyTorch is not installed. " + f"Please, install PyTorch or use native Flax weights." + ) + # Step 1: Get the pytorch file pytorch_model_file = load_state_dict(model_file)