diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py index 6f376e440f..38e85c7285 100644 --- a/src/diffusers/pipelines/photon/__init__.py +++ b/src/diffusers/pipelines/photon/__init__.py @@ -1,16 +1,65 @@ from typing import TYPE_CHECKING -from .pipeline_output import PhotonPipelineOutput -from .pipeline_photon import PhotonPipeline +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) -__all__ = ["PhotonPipeline", "PhotonPipelineOutput"] +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]} -# Make T5GemmaEncoder importable from this module for pipeline loading -if TYPE_CHECKING: - from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - try: + _import_structure["pipeline_photon"] = ["PhotonPipeline"] + +# Import T5GemmaEncoder for pipeline loading compatibility +try: + if is_transformers_available(): from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + + _additional_imports["T5GemmaEncoder"] = T5GemmaEncoder +except ImportError: + pass + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_output import PhotonPipelineOutput + from .pipeline_photon import PhotonPipeline + + try: + if is_transformers_available(): + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder except ImportError: pass + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value)