mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -70,6 +70,44 @@ from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_cre
|
||||
from ..utils.torch_utils import get_device, is_compiled_module
|
||||
|
||||
|
||||
class DeprecatedPipelineMixin:
|
||||
"""
|
||||
A mixin that can be used to mark a pipeline as deprecated.
|
||||
|
||||
Pipelines inheriting from this mixin will raise a warning when instantiated, indicating
|
||||
that they are deprecated and won't receive updates past the specified version.
|
||||
Tests will be skipped for pipelines that inherit from this mixin.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
class MyDeprecatedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
_last_supported_version = "0.20.0"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
```
|
||||
"""
|
||||
|
||||
# Override this in the inheriting class to specify the last version that will support this pipeline
|
||||
_last_supported_version = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Call the parent class's __init__ method
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Get the class name for the warning message
|
||||
class_name = self.__class__.__name__
|
||||
|
||||
# Get the last supported version or use the current version if not specified
|
||||
last_version = getattr(self.__class__, "_last_supported_version", __version__)
|
||||
|
||||
# Raise a warning that this pipeline is deprecated
|
||||
logging.warning(
|
||||
f"The {class_name} pipeline is deprecated and will not receive updates after version {last_version}. "
|
||||
f"Please consider switching to a maintained pipeline."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
|
||||
@@ -1114,6 +1114,12 @@ class PipelineTesterMixin:
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Skip tests for pipelines that inherit from DeprecatedPipelineMixin
|
||||
from diffusers.pipelines.pipeline_utils import DeprecatedPipelineMixin
|
||||
if hasattr(self, "pipeline_class") and issubclass(self.pipeline_class, DeprecatedPipelineMixin):
|
||||
import pytest
|
||||
pytest.skip(f"Skipping tests for deprecated pipeline: {self.pipeline_class.__name__}")
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
|
||||
Reference in New Issue
Block a user