mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Don't use load_state_dict if torch is not installed.
This commit is contained in:
@@ -28,7 +28,6 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
|
||||
from requests import HTTPError
|
||||
|
||||
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
||||
from .modeling_utils import load_state_dict
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
@@ -37,6 +36,7 @@ from .utils import (
|
||||
WEIGHTS_NAME,
|
||||
logging,
|
||||
)
|
||||
from . import is_torch_available
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -391,6 +391,14 @@ class FlaxModelMixin:
|
||||
)
|
||||
|
||||
if from_pt:
|
||||
if is_torch_available():
|
||||
from .modeling_utils import load_state_dict
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model in PyTorch format because PyTorch is not installed. "
|
||||
f"Please, install PyTorch or use native Flax weights."
|
||||
)
|
||||
|
||||
# Step 1: Get the pytorch file
|
||||
pytorch_model_file = load_state_dict(model_file)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user