1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Enhacne] Support maybe_raise_or_warn for peft (#5653)

* Support maybe_raise_or_warn for peft

* fix by comment

* unwrap function
This commit is contained in:
takuoko
2023-11-14 19:40:35 +09:00
committed by GitHub
parent c9c5436c94
commit bfe94a3993

View File

@@ -49,6 +49,7 @@ from ..utils import (
get_class_from_dynamic_module,
is_accelerate_available,
is_accelerate_version,
is_peft_available,
is_torch_version,
is_transformers_available,
logging,
@@ -270,6 +271,20 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token,
)
def _unwrap_model(model):
"""Unwraps a model."""
if is_compiled_module(model):
model = model._orig_mod
if is_peft_available():
from peft import PeftModel
if isinstance(model, PeftModel):
model = model.base_model.model
return model
def maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
@@ -287,9 +302,8 @@ def maybe_raise_or_warn(
# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
sub_model = passed_class_obj[name]
model_cls = sub_model.__class__
if is_compiled_module(sub_model):
model_cls = sub_model._orig_mod.__class__
unwrapped_sub_model = _unwrap_model(sub_model)
model_cls = unwrapped_sub_model.__class__
if not issubclass(model_cls, expected_class_obj):
raise ValueError(