diff --git a/src/diffusers/pipelines/stable_diffusion_safe/__init__.py b/src/diffusers/pipelines/stable_diffusion_safe/__init__.py index 67c6ab1f66..2bab91c552 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/__init__.py @@ -50,13 +50,26 @@ class SafetyConfig(object): _dummy_objects = {} _additional_imports = {} -_import_structure = { - "pipeline_output": ["StableDiffusionSafePipelineOutput"], - "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"], - "safety_checker": ["StableDiffusionSafetyChecker"], -} +_import_structure = {} + _additional_imports.update({"SafetyConfig": SafetyConfig}) +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure.update( + { + "pipeline_output": ["StableDiffusionSafePipelineOutput"], + "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"], + "safety_checker": ["StableDiffusionSafetyChecker"], + } + ) + if TYPE_CHECKING: try: @@ -70,25 +83,16 @@ if TYPE_CHECKING: from .safety_checker import SafeStableDiffusionSafetyChecker else: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects + import sys - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) - else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) - for name, value in _additional_imports.items(): - setattr(sys.modules[__name__], name, value) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py index a09a63476b..8bc8e407d4 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -47,3 +47,5 @@ else: _import_structure, module_spec=__spec__, ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/vq_diffusion/__init__.py b/src/diffusers/pipelines/vq_diffusion/__init__.py index b8fb7f55e8..dac43806a5 100644 --- a/src/diffusers/pipelines/vq_diffusion/__init__.py +++ b/src/diffusers/pipelines/vq_diffusion/__init__.py @@ -51,3 +51,6 @@ else: _import_structure, module_spec=__spec__, ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 13407f2cd1..3a6a464aef 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -41,7 +41,6 @@ if TYPE_CHECKING: from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline - else: import sys @@ -51,3 +50,6 @@ else: _import_structure, module_spec=__spec__, ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value)