mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Enhacne] Support maybe_raise_or_warn for peft (#5653)
* Support maybe_raise_or_warn for peft * fix by comment * unwrap function
This commit is contained in:
@@ -49,6 +49,7 @@ from ..utils import (
|
||||
get_class_from_dynamic_module,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_peft_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
@@ -270,6 +271,20 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap_model(model):
|
||||
"""Unwraps a model."""
|
||||
if is_compiled_module(model):
|
||||
model = model._orig_mod
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftModel
|
||||
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.base_model.model
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def maybe_raise_or_warn(
|
||||
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
||||
):
|
||||
@@ -287,9 +302,8 @@ def maybe_raise_or_warn(
|
||||
# Dynamo wraps the original model in a private class.
|
||||
# I didn't find a public API to get the original class.
|
||||
sub_model = passed_class_obj[name]
|
||||
model_cls = sub_model.__class__
|
||||
if is_compiled_module(sub_model):
|
||||
model_cls = sub_model._orig_mod.__class__
|
||||
unwrapped_sub_model = _unwrap_model(sub_model)
|
||||
model_cls = unwrapped_sub_model.__class__
|
||||
|
||||
if not issubclass(model_cls, expected_class_obj):
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user