diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 80c3fb68a4..1a3fea1e8a 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -28,7 +28,6 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R from requests import HTTPError from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax -from .modeling_utils import load_state_dict from .utils import ( CONFIG_NAME, DIFFUSERS_CACHE, @@ -37,6 +36,7 @@ from .utils import ( WEIGHTS_NAME, logging, ) +from . import is_torch_available logger = logging.get_logger(__name__) @@ -391,6 +391,14 @@ class FlaxModelMixin: ) if from_pt: + if is_torch_available(): + from .modeling_utils import load_state_dict + else: + raise EnvironmentError( + f"Can't load the model in PyTorch format because PyTorch is not installed. " + f"Please, install PyTorch or use native Flax weights." + ) + # Step 1: Get the pytorch file pytorch_model_file = load_state_dict(model_file)