mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
load pipeline from source
This commit is contained in:
@@ -34,6 +34,7 @@ logger = logging.get_logger(__name__)
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
"GaussianDDPMScheduler": ["save_config", "from_config"],
|
||||
"ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
|
||||
"GlideDDIMScheduler": ["save_config", "from_config"],
|
||||
@@ -74,7 +75,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
# set models
|
||||
setattr(self, name, module)
|
||||
|
||||
register_dict = {"_module": self.__module__.split(".")[-1] + ".py"}
|
||||
register_dict = {"_module": self.__module__.split(".")[-1]}
|
||||
self.register(**register_dict)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
@@ -133,19 +134,21 @@ class DiffusionPipeline(ConfigMixin):
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
|
||||
# 2. Get class name and module candidates to load custom models
|
||||
class_name_ = config_dict["_class_name"]
|
||||
module_candidate = config_dict["_module"]
|
||||
module_candidate_name = module_candidate.replace(".py", "")
|
||||
module_candidate_name = config_dict["_module"]
|
||||
module_candidate = module_candidate_name + ".py"
|
||||
|
||||
# 3. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
if cls != DiffusionPipeline:
|
||||
pipeline_class = cls
|
||||
else:
|
||||
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
||||
|
||||
# (TODO - we should allow to load custom pipelines
|
||||
# else we need to load the correct module from the Hub
|
||||
class_name_ = config_dict["_class_name"]
|
||||
module = module_candidate
|
||||
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||
# module = module_candidate
|
||||
# pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||
|
||||
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user