diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e7d34b6237..e760355ff7 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -181,7 +181,6 @@ else: "CLIPImageProjection", "StableDiffusionAttendAndExcitePipeline", "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENTextImagePipeline", @@ -209,6 +208,7 @@ else: "StableDiffusionXLPipeline", ] ) + _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] _import_structure["t2i_adapter"] = [ "StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline", @@ -422,7 +422,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: CLIPImageProjection, StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, - StableDiffusionDiffEditPipeline, StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline, StableDiffusionImageVariationPipeline, @@ -438,6 +437,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) + from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_xl import ( StableDiffusionXLImg2ImgPipeline, diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index dbd79ec1f3..085b46beff 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -67,20 +67,17 @@ try: except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionDepth2ImgPipeline, - StableDiffusionDiffEditPipeline, StableDiffusionPix2PixZeroPipeline, ) _dummy_objects.update( { "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, - "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, } ) else: _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] - _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"] try: if not ( @@ -181,14 +178,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionDepth2ImgPipeline, - StableDiffusionDiffEditPipeline, StableDiffusionPix2PixZeroPipeline, ) else: from .pipeline_stable_diffusion_depth2img import ( StableDiffusionDepth2ImgPipeline, ) - from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline try: if not ( diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py b/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py new file mode 100644 index 0000000000..e2145edb96 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py rename to src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py index 81d936be62..d0d132555e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -40,8 +40,8 @@ from ...utils import ( ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name