1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Feature] AutoModel can load components using model_index.json (#11401)

* update

* update

* update

* update

* addressed PR comments

* update

* addressed PR comments

* added tests

* addressed PR comments

* updates

* update

* addressed PR comments

* update

* fix style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
Ishan Modi
2025-05-26 14:06:36 +05:30
committed by GitHub
parent 049082e013
commit f64fa9492d
3 changed files with 78 additions and 8 deletions

View File

@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from typing import Optional, Union
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import logging
logger = logging.get_logger(__name__)
class AutoModel(ConfigMixin):
@@ -152,15 +155,50 @@ class AutoModel(ConfigMixin):
"token": token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
library = None
orig_class_name = None
library = importlib.import_module("diffusers")
# Always attempt to fetch model_index.json first
try:
cls.config_name = "model_index.json"
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
if subfolder is not None and subfolder in config:
library, orig_class_name = config[subfolder]
load_config_kwargs.update({"subfolder": subfolder})
except EnvironmentError as e:
logger.debug(e)
# Unable to load from model_index.json so fallback to loading from config
if library is None and orig_class_name is None:
cls.config_name = "config.json"
config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)
if "_class_name" in config:
# If we find a class name in the config, we can try to load the model as a diffusers model
orig_class_name = config["_class_name"]
library = "diffusers"
load_config_kwargs.update({"subfolder": subfolder})
elif "model_type" in config:
orig_class_name = "AutoModel"
library = "transformers"
load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder})
else:
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=None,
is_pipeline_module=False,
)
model_cls = getattr(library, orig_class_name, None)
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

View File

@@ -335,14 +335,14 @@ def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name)
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name

View File

@@ -0,0 +1,32 @@
import unittest
from unittest.mock import patch
from transformers import CLIPTextModel, LongformerModel
from diffusers.models import AutoModel, UNet2DConditionModel
class TestAutoModel(unittest.TestCase):
@patch(
"diffusers.models.AutoModel.load_config",
side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}],
)
def test_load_from_config_diffusers_with_subfolder(self, mock_load_config):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
assert isinstance(model, UNet2DConditionModel)
@patch(
"diffusers.models.AutoModel.load_config",
side_effect=[EnvironmentError("File not found"), {"model_type": "clip_text_model"}],
)
def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
assert isinstance(model, CLIPTextModel)
def test_load_from_config_without_subfolder(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-longformer")
assert isinstance(model, LongformerModel)
def test_load_from_model_index(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
assert isinstance(model, CLIPTextModel)