From a216b0bb7fbf713e348edb030e865c2703965bd2 Mon Sep 17 00:00:00 2001 From: Luo Chaofan <79003314+fkcptlst@users.noreply.github.com> Date: Fri, 28 Jun 2024 19:45:46 +0800 Subject: [PATCH] fix: ValueError when using FromOriginalModelMixin in subclasses #8440 (#8454) * fix: ValueError when using FromOriginalModelMixin in subclasses #8440 (cherry picked from commit 92859978436acf844760fc0e992165b489d0180a) * Update src/diffusers/loaders/single_file_model.py Co-authored-by: Dhruv Nair * Update single_file_model.py * Update single_file_model.py --------- Co-authored-by: Dhruv Nair Co-authored-by: Sayak Paul --- src/diffusers/loaders/single_file_model.py | 23 +++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index f537a3f449..dbcf081b1f 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import inspect import re from contextlib import nullcontext @@ -72,6 +73,17 @@ SINGLE_FILE_LOADABLE_CLASSES = { } +def _get_single_file_loadable_mapping_class(cls): + diffusers_module = importlib.import_module(__name__.split(".")[0]) + for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES: + loadable_class = getattr(diffusers_module, loadable_class_str) + + if issubclass(cls, loadable_class): + return loadable_class_str + + return None + + def _get_mapping_function_kwargs(mapping_fn, **kwargs): parameters = inspect.signature(mapping_fn).parameters @@ -149,8 +161,9 @@ class FromOriginalModelMixin: ``` """ - class_name = cls.__name__ - if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + mapping_class_name = _get_single_file_loadable_mapping_class(cls) + # if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + if mapping_class_name is None: raise ValueError( f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}" ) @@ -195,7 +208,7 @@ class FromOriginalModelMixin: revision=revision, ) - mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name] + mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] if original_config: @@ -207,7 +220,7 @@ class FromOriginalModelMixin: if config_mapping_fn is None: raise ValueError( ( - f"`original_config` has been provided for {class_name} but no mapping function" + f"`original_config` has been provided for {mapping_class_name} but no mapping function" "was found to convert the original config to a Diffusers config in" "`diffusers.loaders.single_file_utils`" ) @@ -267,7 +280,7 @@ class FromOriginalModelMixin: ) if not diffusers_format_checkpoint: raise SingleFileComponentError( - f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." + f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint." ) ctx = init_empty_weights if is_accelerate_available() else nullcontext