mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Remote code] Add functionality to run remote models, schedulers, pipelines (#5472)
* upload custom remote poc * up * make style * finish * better name * Apply suggestions from code review * Update tests/pipelines/test_pipelines.py * more fixes * remove ipdb * more fixes * fix more * finish tests --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
committed by
GitHub
parent
5b448a5e5d
commit
cee1cd6e9c
@@ -485,10 +485,18 @@ class ConfigMixin:
|
||||
|
||||
# remove attributes from orig class that cannot be expected
|
||||
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
|
||||
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
|
||||
if (
|
||||
isinstance(orig_cls_name, str)
|
||||
and orig_cls_name != cls.__name__
|
||||
and hasattr(diffusers_library, orig_cls_name)
|
||||
):
|
||||
orig_cls = getattr(diffusers_library, orig_cls_name)
|
||||
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
|
||||
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
|
||||
elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
|
||||
raise ValueError(
|
||||
"Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
|
||||
)
|
||||
|
||||
# remove private attributes
|
||||
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
||||
|
||||
@@ -33,8 +33,6 @@ from packaging import version
|
||||
from requests.exceptions import HTTPError
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import diffusers
|
||||
|
||||
from .. import __version__
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
@@ -305,13 +303,23 @@ def maybe_raise_or_warn(
|
||||
)
|
||||
|
||||
|
||||
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
|
||||
def get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||
):
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
component_folder = os.path.join(cache_dir, component_name)
|
||||
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
|
||||
# load custom component
|
||||
class_obj = get_class_from_dynamic_module(
|
||||
component_folder, module_file=library_name + ".py", class_name=class_name
|
||||
)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
@@ -323,7 +331,15 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
|
||||
|
||||
|
||||
def _get_pipeline_class(
|
||||
class_obj, config, load_connected_pipeline=False, custom_pipeline=None, cache_dir=None, revision=None
|
||||
class_obj,
|
||||
config,
|
||||
load_connected_pipeline=False,
|
||||
custom_pipeline=None,
|
||||
repo_id=None,
|
||||
hub_revision=None,
|
||||
class_name=None,
|
||||
cache_dir=None,
|
||||
revision=None,
|
||||
):
|
||||
if custom_pipeline is not None:
|
||||
if custom_pipeline.endswith(".py"):
|
||||
@@ -331,11 +347,19 @@ def _get_pipeline_class(
|
||||
# decompose into folder & file
|
||||
file_name = path.name
|
||||
custom_pipeline = path.parent.absolute()
|
||||
elif repo_id is not None:
|
||||
file_name = f"{custom_pipeline}.py"
|
||||
custom_pipeline = repo_id
|
||||
else:
|
||||
file_name = CUSTOM_PIPELINE_FILE_NAME
|
||||
|
||||
return get_class_from_dynamic_module(
|
||||
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
|
||||
custom_pipeline,
|
||||
module_file=file_name,
|
||||
class_name=class_name,
|
||||
repo_id=repo_id,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision if hub_revision is None else hub_revision,
|
||||
)
|
||||
|
||||
if class_obj != DiffusionPipeline:
|
||||
@@ -383,11 +407,18 @@ def load_sub_model(
|
||||
variant: str,
|
||||
low_cpu_mem_usage: bool,
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
revision: str = None,
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
# retrieve class candidates
|
||||
class_obj, class_candidates = get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module
|
||||
library_name,
|
||||
class_name,
|
||||
importable_classes,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
component_name=name,
|
||||
cache_dir=cached_folder,
|
||||
)
|
||||
|
||||
load_method_name = None
|
||||
@@ -414,14 +445,15 @@ def load_sub_model(
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
|
||||
# add kwargs to loading method
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
loading_kwargs = {}
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs["torch_dtype"] = torch_dtype
|
||||
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
|
||||
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
|
||||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
|
||||
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
@@ -501,7 +533,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
def register_modules(self, **kwargs):
|
||||
# import it here to avoid circular import
|
||||
from diffusers import pipelines
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
for name, module in kwargs.items():
|
||||
# retrieve library
|
||||
@@ -1080,11 +1113,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# 3. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
custom_class_name = None
|
||||
if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
|
||||
custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
|
||||
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
|
||||
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
||||
):
|
||||
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config_dict,
|
||||
load_connected_pipeline=load_connected_pipeline,
|
||||
custom_pipeline=custom_pipeline,
|
||||
class_name=custom_class_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=custom_revision,
|
||||
)
|
||||
@@ -1223,6 +1266,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant=variant,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cached_folder=cached_folder,
|
||||
revision=revision,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
||||
@@ -1542,6 +1586,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
|
||||
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
|
||||
with `.onnx` and `.pb`.
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
|
||||
option should only be set to `True` for repositories you trust and in which you have read the code, as
|
||||
it will execute code present on the Hub on your local machine.
|
||||
|
||||
Returns:
|
||||
`os.PathLike`:
|
||||
@@ -1569,6 +1617,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
@@ -1604,15 +1653,34 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
|
||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
# optionally create a custom component <> custom file mapping
|
||||
custom_components = {}
|
||||
for component in folder_names:
|
||||
module_candidate = config_dict[component][0]
|
||||
|
||||
if module_candidate is None:
|
||||
continue
|
||||
|
||||
candidate_file = os.path.join(component, module_candidate + ".py")
|
||||
|
||||
if candidate_file in filenames:
|
||||
custom_components[component] = module_candidate
|
||||
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
|
||||
raise ValueError(
|
||||
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
|
||||
)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
deprecation_message = (
|
||||
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
@@ -1636,12 +1704,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
custom_class_name = None
|
||||
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
|
||||
custom_pipeline = config_dict["_class_name"][0]
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
|
||||
# add custom component files
|
||||
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
|
||||
# add custom pipeline file
|
||||
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
|
||||
|
||||
@@ -1652,12 +1729,32 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
|
||||
load_components_from_hub = len(custom_components) > 0
|
||||
|
||||
if load_pipe_from_hub and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
|
||||
if load_components_from_hub and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
|
||||
# retrieve passed components that should not be downloaded
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config_dict,
|
||||
load_connected_pipeline=load_connected_pipeline,
|
||||
custom_pipeline=custom_pipeline,
|
||||
repo_id=pretrained_model_name if load_pipe_from_hub else None,
|
||||
hub_revision=revision,
|
||||
class_name=custom_class_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=custom_revision,
|
||||
)
|
||||
@@ -1754,9 +1851,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# retrieve pipeline class from local file
|
||||
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
|
||||
cls_name = cls_name[4:] if cls_name.startswith("Flax") else cls_name
|
||||
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
|
||||
|
||||
pipeline_class = getattr(diffusers, cls_name, None)
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
|
||||
|
||||
if pipeline_class is not None and pipeline_class._load_connected_pipes:
|
||||
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
|
||||
|
||||
@@ -862,6 +862,58 @@ class CustomPipelineTests(unittest.TestCase):
|
||||
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
|
||||
assert output_str == "This is a test"
|
||||
|
||||
def test_remote_components(self):
|
||||
# make sure that trust remote code has to be passed
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-components")
|
||||
|
||||
# Check that only loading custom componets "my_unet", "my_scheduler" works
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-sdxl-custom-components", trust_remote_code=True
|
||||
)
|
||||
|
||||
assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
|
||||
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
|
||||
assert pipeline.__class__.__name__ == "StableDiffusionXLPipeline"
|
||||
|
||||
pipeline = pipeline.to(torch_device)
|
||||
images = pipeline("test", num_inference_steps=2, output_type="np")[0]
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
# Check that only loading custom componets "my_unet", "my_scheduler" and explicit custom pipeline works
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-sdxl-custom-components", custom_pipeline="my_pipeline", trust_remote_code=True
|
||||
)
|
||||
|
||||
assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
|
||||
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
|
||||
assert pipeline.__class__.__name__ == "MyPipeline"
|
||||
|
||||
pipeline = pipeline.to(torch_device)
|
||||
images = pipeline("test", num_inference_steps=2, output_type="np")[0]
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
def test_remote_auto_custom_pipe(self):
|
||||
# make sure that trust remote code has to be passed
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-all")
|
||||
|
||||
# Check that only loading custom componets "my_unet", "my_scheduler" and auto custom pipeline works
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-sdxl-custom-all", trust_remote_code=True
|
||||
)
|
||||
|
||||
assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
|
||||
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
|
||||
assert pipeline.__class__.__name__ == "MyPipeline"
|
||||
|
||||
pipeline = pipeline.to(torch_device)
|
||||
images = pipeline("test", num_inference_steps=2, output_type="np")[0]
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
def test_local_custom_pipeline_repo(self):
|
||||
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user