diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py index 30699b6a1b..be114ca9b1 100644 --- a/examples/community/stable_diffusion_mega.py +++ b/examples/community/stable_diffusion_mega.py @@ -50,6 +50,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -60,6 +61,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline): scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: @@ -85,6 +87,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.register_to_config(requires_safety_checker=requires_safety_checker) @property def components(self) -> Dict[str, Any]: