1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-05-19 22:30:24 +05:30
parent 915c537891
commit 9ccb82dc77
2 changed files with 44 additions and 0 deletions

View File

@@ -70,6 +70,44 @@ from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_cre
from ..utils.torch_utils import get_device, is_compiled_module
class DeprecatedPipelineMixin:
"""
A mixin that can be used to mark a pipeline as deprecated.
Pipelines inheriting from this mixin will raise a warning when instantiated, indicating
that they are deprecated and won't receive updates past the specified version.
Tests will be skipped for pipelines that inherit from this mixin.
Example usage:
```python
class MyDeprecatedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
_last_supported_version = "0.20.0"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
```
"""
# Override this in the inheriting class to specify the last version that will support this pipeline
_last_supported_version = None
def __init__(self, *args, **kwargs):
# Call the parent class's __init__ method
super().__init__(*args, **kwargs)
# Get the class name for the warning message
class_name = self.__class__.__name__
# Get the last supported version or use the current version if not specified
last_version = getattr(self.__class__, "_last_supported_version", __version__)
# Raise a warning that this pipeline is deprecated
logging.warning(
f"The {class_name} pipeline is deprecated and will not receive updates after version {last_version}. "
f"Please consider switching to a maintained pipeline."
)
if is_torch_npu_available():
import torch_npu # noqa: F401

View File

@@ -1114,6 +1114,12 @@ class PipelineTesterMixin:
torch._dynamo.reset()
gc.collect()
backend_empty_cache(torch_device)
# Skip tests for pipelines that inherit from DeprecatedPipelineMixin
from diffusers.pipelines.pipeline_utils import DeprecatedPipelineMixin
if hasattr(self, "pipeline_class") and issubclass(self.pipeline_class, DeprecatedPipelineMixin):
import pytest
pytest.skip(f"Skipping tests for deprecated pipeline: {self.pipeline_class.__name__}")
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors