From 9ccb82dc772fa58b194db996f22467fe7670e885 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 19 May 2025 22:30:24 +0530 Subject: [PATCH] update --- src/diffusers/pipelines/pipeline_utils.py | 38 +++++++++++++++++++++++ tests/pipelines/test_pipelines_common.py | 6 ++++ 2 files changed, 44 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 4348143b8e..f51d5edf73 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -70,6 +70,44 @@ from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_cre from ..utils.torch_utils import get_device, is_compiled_module +class DeprecatedPipelineMixin: + """ + A mixin that can be used to mark a pipeline as deprecated. + + Pipelines inheriting from this mixin will raise a warning when instantiated, indicating + that they are deprecated and won't receive updates past the specified version. + Tests will be skipped for pipelines that inherit from this mixin. + + Example usage: + ```python + class MyDeprecatedPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + _last_supported_version = "0.20.0" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + ``` + """ + + # Override this in the inheriting class to specify the last version that will support this pipeline + _last_supported_version = None + + def __init__(self, *args, **kwargs): + # Call the parent class's __init__ method + super().__init__(*args, **kwargs) + + # Get the class name for the warning message + class_name = self.__class__.__name__ + + # Get the last supported version or use the current version if not specified + last_version = getattr(self.__class__, "_last_supported_version", __version__) + + # Raise a warning that this pipeline is deprecated + logging.warning( + f"The {class_name} pipeline is deprecated and will not receive updates after version {last_version}. " + f"Please consider switching to a maintained pipeline." + ) + + if is_torch_npu_available(): import torch_npu # noqa: F401 diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index af3a832d31..e24c632b06 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1114,6 +1114,12 @@ class PipelineTesterMixin: torch._dynamo.reset() gc.collect() backend_empty_cache(torch_device) + + # Skip tests for pipelines that inherit from DeprecatedPipelineMixin + from diffusers.pipelines.pipeline_utils import DeprecatedPipelineMixin + if hasattr(self, "pipeline_class") and issubclass(self.pipeline_class, DeprecatedPipelineMixin): + import pytest + pytest.skip(f"Skipping tests for deprecated pipeline: {self.pipeline_class.__name__}") def tearDown(self): # clean up the VRAM after each test in case of CUDA runtime errors