mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Feature] AutoModel can load components using model_index.json (#11401)
* update * update * update * update * addressed PR comments * update * addressed PR comments * added tests * addressed PR comments * updates * update * addressed PR comments * update * fix style --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -12,13 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AutoModel(ConfigMixin):
|
||||
@@ -152,15 +155,50 @@ class AutoModel(ConfigMixin):
|
||||
"token": token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
}
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
orig_class_name = config["_class_name"]
|
||||
library = None
|
||||
orig_class_name = None
|
||||
|
||||
library = importlib.import_module("diffusers")
|
||||
# Always attempt to fetch model_index.json first
|
||||
try:
|
||||
cls.config_name = "model_index.json"
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
|
||||
if subfolder is not None and subfolder in config:
|
||||
library, orig_class_name = config[subfolder]
|
||||
load_config_kwargs.update({"subfolder": subfolder})
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.debug(e)
|
||||
|
||||
# Unable to load from model_index.json so fallback to loading from config
|
||||
if library is None and orig_class_name is None:
|
||||
cls.config_name = "config.json"
|
||||
config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)
|
||||
|
||||
if "_class_name" in config:
|
||||
# If we find a class name in the config, we can try to load the model as a diffusers model
|
||||
orig_class_name = config["_class_name"]
|
||||
library = "diffusers"
|
||||
load_config_kwargs.update({"subfolder": subfolder})
|
||||
elif "model_type" in config:
|
||||
orig_class_name = "AutoModel"
|
||||
library = "transformers"
|
||||
load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder})
|
||||
else:
|
||||
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
|
||||
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
model_cls, _ = get_class_obj_and_candidates(
|
||||
library_name=library,
|
||||
class_name=orig_class_name,
|
||||
importable_classes=ALL_IMPORTABLE_CLASSES,
|
||||
pipelines=None,
|
||||
is_pipeline_module=False,
|
||||
)
|
||||
|
||||
model_cls = getattr(library, orig_class_name, None)
|
||||
if model_cls is None:
|
||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
|
||||
|
||||
|
||||
@@ -335,14 +335,14 @@ def get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||
):
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
component_folder = os.path.join(cache_dir, component_name)
|
||||
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
|
||||
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
|
||||
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
|
||||
elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
|
||||
# load custom component
|
||||
class_obj = get_class_from_dynamic_module(
|
||||
component_folder, module_file=library_name + ".py", class_name=class_name
|
||||
|
||||
32
tests/models/test_models_auto.py
Normal file
32
tests/models/test_models_auto.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import CLIPTextModel, LongformerModel
|
||||
|
||||
from diffusers.models import AutoModel, UNet2DConditionModel
|
||||
|
||||
|
||||
class TestAutoModel(unittest.TestCase):
|
||||
@patch(
|
||||
"diffusers.models.AutoModel.load_config",
|
||||
side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}],
|
||||
)
|
||||
def test_load_from_config_diffusers_with_subfolder(self, mock_load_config):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
|
||||
assert isinstance(model, UNet2DConditionModel)
|
||||
|
||||
@patch(
|
||||
"diffusers.models.AutoModel.load_config",
|
||||
side_effect=[EnvironmentError("File not found"), {"model_type": "clip_text_model"}],
|
||||
)
|
||||
def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
|
||||
def test_load_from_config_without_subfolder(self):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-longformer")
|
||||
assert isinstance(model, LongformerModel)
|
||||
|
||||
def test_load_from_model_index(self):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
Reference in New Issue
Block a user