diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 026cd7b2dd..c4085f6f20 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -33,7 +33,7 @@ from ..utils import ( ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, - _should_remap_transformers_class, + _maybe_remap_transformers_class, deprecate, get_class_from_dynamic_module, is_accelerate_available, @@ -360,7 +360,7 @@ def maybe_raise_or_warn( # Handle deprecated Transformers classes if library_name == "transformers": - class_name = _should_remap_transformers_class(class_name) or class_name + class_name = _maybe_remap_transformers_class(class_name) or class_name class_obj = getattr(library, class_name) class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} @@ -399,7 +399,7 @@ def simple_get_class_obj(library_name, class_name): # Handle deprecated Transformers classes if library_name == "transformers": - class_name = _should_remap_transformers_class(class_name) or class_name + class_name = _maybe_remap_transformers_class(class_name) or class_name class_obj = getattr(library, class_name) @@ -429,7 +429,7 @@ def get_class_obj_and_candidates( # Handle deprecated Transformers classes if library_name == "transformers": - class_name = _should_remap_transformers_class(class_name) or class_name + class_name = _maybe_remap_transformers_class(class_name) or class_name class_obj = getattr(library, class_name) class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index a36eef9589..d8e1a55401 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -38,7 +38,7 @@ from .constants import ( WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ) -from .deprecation_utils import _should_remap_transformers_class, deprecate +from .deprecation_utils import _maybe_remap_transformers_class, deprecate from .doc_utils import replace_example_docstring from .dynamic_modules_utils import get_class_from_dynamic_module from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py index 5f276382b6..d76623541b 100644 --- a/src/diffusers/utils/deprecation_utils.py +++ b/src/diffusers/utils/deprecation_utils.py @@ -4,6 +4,10 @@ from typing import Any, Dict, Optional, Union from packaging import version +from ..utils import logging + + +logger = logging.get_logger(__name__) # Mapping for deprecated Transformers classes to their replacements # This is used to handle models that reference deprecated class names in their configs @@ -22,7 +26,7 @@ _TRANSFORMERS_CLASS_REMAPPING = { } -def _should_remap_transformers_class(class_name: str) -> Optional[str]: +def _maybe_remap_transformers_class(class_name: str) -> Optional[str]: """ Check if a Transformers class should be remapped to a newer version. @@ -42,6 +46,8 @@ def _should_remap_transformers_class(class_name: str) -> Optional[str]: # Only remap if the transformers version meets the requirement if is_transformers_version(operation, required_version): + new_class = mapping["new_class"] + logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.") return mapping["new_class"] return None