From 491bc9f5a640bbf46a97a8e52d6eff7e70eb8e4b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 6 Jul 2023 09:36:25 +0530 Subject: [PATCH] fix setup for watermarking. --- setup.py | 2 ++ .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 8 +++++++- tests/others/test_dependencies.py | 2 ++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6e6873447a..513becaefb 100644 --- a/setup.py +++ b/setup.py @@ -117,6 +117,7 @@ _deps = [ "torchvision", "transformers>=4.25.1", "urllib3<=2.0.0", + "invisible-watermark" ] # this is a lookup table with items like: @@ -207,6 +208,7 @@ extras["test"] = deps_list( "scipy", "torchvision", "transformers", + "invisible-watermark", ) extras["torch"] = deps_list("torch", "accelerate") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c50381c2eb..fbc1db2b32 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -31,13 +31,19 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, + is_invisible_watermark_available, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput -from .watermark import StableDiffusionXLWatermarker + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker +else: + raise ImportError("StableDiffusionXLWatermarker requires the `invisible-watermark` library installed.") logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/tests/others/test_dependencies.py b/tests/others/test_dependencies.py index 9bee7a0db3..3436cf92d8 100644 --- a/tests/others/test_dependencies.py +++ b/tests/others/test_dependencies.py @@ -34,4 +34,6 @@ class DependencyTester(unittest.TestCase): for backend in cls_module._backends: if backend == "k_diffusion": backend = "k-diffusion" + elif backend == "invisible_watermark": + backend = "invisible-watermark" assert backend in deps, f"{backend} is not in the deps table!"