1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix: ValueError when using FromOriginalModelMixin in subclasses #8440 (#8454)

* fix: ValueError when using FromOriginalModelMixin in subclasses #8440

(cherry picked from commit 9285997843)

* Update src/diffusers/loaders/single_file_model.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update single_file_model.py

* Update single_file_model.py

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Luo Chaofan
2024-06-28 19:45:46 +08:00
committed by GitHub
parent 150142c537
commit a216b0bb7f

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import re
from contextlib import nullcontext
@@ -72,6 +73,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
}
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
loadable_class = getattr(diffusers_module, loadable_class_str)
if issubclass(cls, loadable_class):
return loadable_class_str
return None
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
parameters = inspect.signature(mapping_fn).parameters
@@ -149,8 +161,9 @@ class FromOriginalModelMixin:
```
"""
class_name = cls.__name__
if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
if mapping_class_name is None:
raise ValueError(
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
)
@@ -195,7 +208,7 @@ class FromOriginalModelMixin:
revision=revision,
)
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
if original_config:
@@ -207,7 +220,7 @@ class FromOriginalModelMixin:
if config_mapping_fn is None:
raise ValueError(
(
f"`original_config` has been provided for {class_name} but no mapping function"
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
"was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`"
)
@@ -267,7 +280,7 @@ class FromOriginalModelMixin:
)
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext