mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* fix: ValueError when using FromOriginalModelMixin in subclasses #8440
(cherry picked from commit 9285997843)
* Update src/diffusers/loaders/single_file_model.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* Update single_file_model.py
* Update single_file_model.py
---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
@@ -72,6 +73,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
}
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
loadable_class = getattr(diffusers_module, loadable_class_str)
|
||||
|
||||
if issubclass(cls, loadable_class):
|
||||
return loadable_class_str
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
||||
parameters = inspect.signature(mapping_fn).parameters
|
||||
|
||||
@@ -149,8 +161,9 @@ class FromOriginalModelMixin:
|
||||
```
|
||||
"""
|
||||
|
||||
class_name = cls.__name__
|
||||
if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
|
||||
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
if mapping_class_name is None:
|
||||
raise ValueError(
|
||||
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
|
||||
)
|
||||
@@ -195,7 +208,7 @@ class FromOriginalModelMixin:
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||
|
||||
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
||||
if original_config:
|
||||
@@ -207,7 +220,7 @@ class FromOriginalModelMixin:
|
||||
if config_mapping_fn is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"`original_config` has been provided for {class_name} but no mapping function"
|
||||
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
|
||||
"was found to convert the original config to a Diffusers config in"
|
||||
"`diffusers.loaders.single_file_utils`"
|
||||
)
|
||||
@@ -267,7 +280,7 @@ class FromOriginalModelMixin:
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
Reference in New Issue
Block a user