1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
DN6
2025-10-23 08:18:57 +05:30
parent 6098e45b36
commit a4415e2bd6
3 changed files with 12 additions and 6 deletions

View File

@@ -33,7 +33,7 @@ from ..utils import (
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_should_remap_transformers_class,
_maybe_remap_transformers_class,
deprecate,
get_class_from_dynamic_module,
is_accelerate_available,
@@ -360,7 +360,7 @@ def maybe_raise_or_warn(
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _should_remap_transformers_class(class_name) or class_name
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
@@ -399,7 +399,7 @@ def simple_get_class_obj(library_name, class_name):
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _should_remap_transformers_class(class_name) or class_name
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name)
@@ -429,7 +429,7 @@ def get_class_obj_and_candidates(
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _should_remap_transformers_class(class_name) or class_name
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

View File

@@ -38,7 +38,7 @@ from .constants import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from .deprecation_utils import _should_remap_transformers_class, deprecate
from .deprecation_utils import _maybe_remap_transformers_class, deprecate
from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video

View File

@@ -4,6 +4,10 @@ from typing import Any, Dict, Optional, Union
from packaging import version
from ..utils import logging
logger = logging.get_logger(__name__)
# Mapping for deprecated Transformers classes to their replacements
# This is used to handle models that reference deprecated class names in their configs
@@ -22,7 +26,7 @@ _TRANSFORMERS_CLASS_REMAPPING = {
}
def _should_remap_transformers_class(class_name: str) -> Optional[str]:
def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
"""
Check if a Transformers class should be remapped to a newer version.
@@ -42,6 +46,8 @@ def _should_remap_transformers_class(class_name: str) -> Optional[str]:
# Only remap if the transformers version meets the requirement
if is_transformers_version(operation, required_version):
new_class = mapping["new_class"]
logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
return mapping["new_class"]
return None