Allow AutoModel to support custom model code (#12353)
* update
* update
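For context, a minimal usage sketch of what this change enables; the repository id is hypothetical, and the only assumption is that the repo ships its own modeling code referenced from its config (see the `auto_map` handling in the diff below):

```python
# Hypothetical usage of the feature this PR adds; "someuser/custom-model" is a made-up repo id.
from diffusers import AutoModel

model = AutoModel.from_pretrained(
    "someuser/custom-model",
    trust_remote_code=True,  # opt in to downloading and executing the repo's custom modeling code
)
```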
@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code


logger = logging.get_logger(__name__)
@@ -114,6 +115,8 @@ class AutoModel(ConfigMixin):
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether to trust remote code.

<Tip>

@@ -140,22 +143,22 @@ class AutoModel(ConfigMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)

load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"token": token,
"local_files_only": local_files_only,
"revision": revision,
}
hub_kwargs_names = [
"cache_dir",
"force_download",
"local_files_only",
"proxies",
"resume_download",
"revision",
"token",
]
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}

# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}

library = None
orig_class_name = None
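As a side note, the kwarg handling in this hunk follows a pop-then-filter pattern; a standalone sketch with made-up values (plain dicts, not the diffusers code itself):

```python
# Standalone illustration of the pop/filter pattern above; the example kwargs are made up.
kwargs = {"cache_dir": "/tmp/hub", "revision": "main", "torch_dtype": "float16"}

hub_kwargs_names = ["cache_dir", "force_download", "local_files_only", "proxies", "resume_download", "revision", "token"]
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
# hub_kwargs now holds every Hub-related option (missing ones default to None),
# while kwargs keeps only model-specific options such as torch_dtype.

# Config loading does not need subfolder or resume_download, so they are filtered out.
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
```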
@@ -189,15 +192,35 @@ class AutoModel(ConfigMixin):
else:
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")

from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not appear to have any custom code or does not have a valid `config.json` file."
)

model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=None,
is_pipeline_module=False,
)
if has_remote_code and trust_remote_code:
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
model_cls = get_class_from_dynamic_module(
pretrained_model_or_path,
subfolder=subfolder,
module_file=module_file,
class_name=class_name,
**hub_kwargs,
**kwargs,
)
else:
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates

model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=None,
is_pipeline_module=False,
)

if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
@@ -247,6 +247,7 @@ def find_pipeline_class(loaded_module):
def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
subfolder: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
@@ -353,6 +354,7 @@ get_cached_module_file(
resolved_module_file = hf_hub_download(
pretrained_model_name_or_path,
module_file,
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
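The change in these hunks is forwarding the new `subfolder` argument down to `hf_hub_download`. Roughly, with a hypothetical repo and file name, that call behaves like:

```python
# Rough illustration of what forwarding `subfolder` to hf_hub_download does; repo id and file are made up.
from huggingface_hub import hf_hub_download

local_path = hf_hub_download(
    repo_id="someuser/custom-model",
    filename="my_model.py",
    subfolder="transformer",  # resolves transformer/my_model.py inside the repo
)
print(local_path)
```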
@@ -410,6 +412,7 @@ get_cached_module_file(
get_cached_module_file(
pretrained_model_name_or_path,
f"{module_needed}.py",
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -424,6 +427,7 @@ get_cached_module_file(
def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
subfolder: Optional[str] = None,
class_name: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
@@ -497,6 +501,7 @@ def get_class_from_dynamic_module(
final_module = get_cached_module_file(
pretrained_model_name_or_path,
module_file,
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
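Taken together, the `subfolder` plumbing matters for repos whose custom module lives next to the weights inside a subfolder. A hypothetical layout and load call (all names are made up):

```python
# Hypothetical repo layout this change is meant to support:
#
#   someuser/custom-model/
#   └── transformer/
#       ├── config.json                          # contains "auto_map": {"AutoModel": "my_model.MyCustomTransformer"}
#       ├── my_model.py                          # defines MyCustomTransformer
#       └── diffusion_pytorch_model.safetensors
#
from diffusers import AutoModel

model = AutoModel.from_pretrained(
    "someuser/custom-model",
    subfolder="transformer",
    trust_remote_code=True,
)
```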