From 7c2a58fd4d4ce766d68f46dfdb3e7d9dff0e6a37 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 4 Nov 2022 14:58:52 +0100 Subject: [PATCH] Move accelerate to a soft-dependency (#1134) * finish * finish * Update src/diffusers/modeling_utils.py * Update src/diffusers/pipeline_utils.py Co-authored-by: Anton Lozhkov * more fixes * fix Co-authored-by: Anton Lozhkov --- src/diffusers/__init__.py | 8 - src/diffusers/modeling_utils.py | 34 +- src/diffusers/pipeline_utils.py | 12 +- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/dummy_pt_objects.py | 15 - .../dummy_torch_and_accelerate_objects.py | 452 ------------------ src/diffusers/utils/import_utils.py | 38 ++ tests/repo_utils/test_check_dummies.py | 4 +- 8 files changed, 82 insertions(+), 482 deletions(-) delete mode 100644 src/diffusers/utils/dummy_torch_and_accelerate_objects.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 22b6589973..e4a69641d5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,5 +1,4 @@ from .utils import ( - is_accelerate_available, is_flax_available, is_inflect_available, is_onnx_available, @@ -17,13 +16,6 @@ from .onnx_utils import OnnxRuntimeModel from .utils import logging -# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py" -# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available -if is_torch_available() and not is_accelerate_available(): - error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501 - raise ImportError(error_msg) - - if is_torch_available(): from .modeling_utils import ModelMixin from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 9e05672bf1..1e91ccd56a 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -21,15 +21,20 @@ from typing import Callable, List, Optional, Tuple, Union import torch from torch import Tensor, device -import accelerate -from accelerate.utils import set_module_tensor_to_device -from accelerate.utils.versions import is_torch_version from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError from . import __version__ -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging +from .utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_NAME, + is_accelerate_available, + is_torch_version, + logging, +) logger = logging.get_logger(__name__) @@ -41,6 +46,12 @@ else: _LOW_CPU_MEM_USAGE_DEFAULT = False +if is_accelerate_available(): + import accelerate + from accelerate.utils import set_module_tensor_to_device + from accelerate.utils.versions import is_torch_version + + def get_parameter_device(parameter: torch.nn.Module): try: return next(parameter.parameters()).device @@ -319,6 +330,21 @@ class ModelMixin(torch.nn.Module): device_map = kwargs.pop("device_map", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warn( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 36c2d5b888..97e196e723 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -25,7 +25,6 @@ import torch import diffusers import PIL -from accelerate.utils.versions import is_torch_version from huggingface_hub import snapshot_download from packaging import version from PIL import Image @@ -43,6 +42,8 @@ from .utils import ( WEIGHTS_NAME, BaseOutput, deprecate, + is_accelerate_available, + is_torch_version, is_transformers_available, logging, ) @@ -397,6 +398,15 @@ class DiffusionPipeline(ConfigMixin): device_map = kwargs.pop("device_map", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warn( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 7395f4edfa..3fa477e7dc 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -31,6 +31,7 @@ from .import_utils import ( is_scipy_available, is_tf_available, is_torch_available, + is_torch_version, is_transformers_available, is_unidecode_available, requires_backends, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 833f2b6c50..25aa82d6c5 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -272,21 +272,6 @@ class ScoreSdeVePipeline(metaclass=DummyObject): requires_backends(cls, ["torch"]) -class VQDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class DDIMScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_accelerate_objects.py b/src/diffusers/utils/dummy_torch_and_accelerate_objects.py deleted file mode 100644 index 335e3ca24d..0000000000 --- a/src/diffusers/utils/dummy_torch_and_accelerate_objects.py +++ /dev/null @@ -1,452 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -# flake8: noqa - -from ..utils import DummyObject, requires_backends - - -class ModelMixin(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class AutoencoderKL(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class Transformer2DModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class UNet1DModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class UNet2DConditionModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class UNet2DModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class VQModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -def get_constant_schedule(*args, **kwargs): - requires_backends(get_constant_schedule, ["torch", "accelerate"]) - - -def get_constant_schedule_with_warmup(*args, **kwargs): - requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_cosine_schedule_with_warmup(*args, **kwargs): - requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): - requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_linear_schedule_with_warmup(*args, **kwargs): - requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): - requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_scheduler(*args, **kwargs): - requires_backends(get_scheduler, ["torch", "accelerate"]) - - -class DiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DanceDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDIMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDPMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class KarrasVePipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class LDMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class PNDMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class RePaintPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class ScoreSdeVePipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDIMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDPMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class EulerAncestralDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class EulerDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class IPNDMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class KarrasVeScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class PNDMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class RePaintScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class SchedulerMixin(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class ScoreSdeVeScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class VQDiffusionScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class EMAModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 4ea02dcc94..005cbb6170 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -15,11 +15,14 @@ Import utilities: Utilities related to imports and our lazy inits. """ import importlib.util +import operator as op import os import sys from collections import OrderedDict +from typing import Union from packaging import version +from packaging.version import Version, parse from . import logging @@ -40,6 +43,8 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None @@ -309,3 +314,36 @@ class DummyObject(type): if key.startswith("_"): return super().__getattr__(cls, key) requires_backends(cls, cls._backends) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 +def is_torch_version(operation: str, version: str): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) diff --git a/tests/repo_utils/test_check_dummies.py b/tests/repo_utils/test_check_dummies.py index 0331b5e8c2..d8fa9ce105 100644 --- a/tests/repo_utils/test_check_dummies.py +++ b/tests/repo_utils/test_check_dummies.py @@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase): def test_read_init(self): objects = read_init() # We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects - self.assertIn("torch_and_accelerate", objects) + self.assertIn("torch", objects) self.assertIn("torch_and_transformers", objects) self.assertIn("flax_and_transformers", objects) self.assertIn("torch_and_transformers_and_onnx", objects) # Likewise, we can't assert on the exact content of a key - self.assertIn("UNet2DModel", objects["torch_and_accelerate"]) + self.assertIn("UNet2DModel", objects["torch"]) self.assertIn("FlaxUNet2DConditionModel", objects["flax"]) self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"]) self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])