From ae4112d2bbfa363f2f3049daad54f0fb89c34499 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 6 Dec 2022 11:18:53 +0100 Subject: [PATCH] Mega community pipeline (#1561) * Mega community pipeline * fix --- examples/community/stable_diffusion_mega.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py index 30699b6a1b..be114ca9b1 100644 --- a/examples/community/stable_diffusion_mega.py +++ b/examples/community/stable_diffusion_mega.py @@ -50,6 +50,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -60,6 +61,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline): scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: @@ -85,6 +87,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.register_to_config(requires_safety_checker=requires_safety_checker) @property def components(self) -> Dict[str, Any]: