diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index c281c772db..e63009b49c 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -161,6 +161,10 @@ class FlaxDiffusionPipeline(ConfigMixin): for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) + if sub_model is None: + # edge case for saving a pipeline with safety_checker=None + continue + model_cls = sub_model.__class__ save_method_name = None @@ -367,6 +371,11 @@ class FlaxDiffusionPipeline(ConfigMixin): # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): + if class_name is None: + # edge case for when the pipeline was saved with safety_checker=None + init_kwargs[name] = None + continue + is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None sub_model_should_be_defined = True diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 94c1e135ab..4ba8d2d930 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -176,6 +176,10 @@ class DiffusionPipeline(ConfigMixin): for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) + if sub_model is None: + # edge case for saving a pipeline with safety_checker=None + continue + model_cls = sub_model.__class__ save_method_name = None @@ -477,6 +481,11 @@ class DiffusionPipeline(ConfigMixin): # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): + if class_name is None: + # edge case for when the pipeline was saved with safety_checker=None + init_kwargs[name] = None + continue + # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names if class_name.startswith("Flax"): class_name = class_name[4:] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 260d58e94b..ded2470cc2 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -15,6 +15,7 @@ import gc import random +import tempfile import time import unittest @@ -318,6 +319,16 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image = pipe("example prompt", num_inference_steps=2).images[0] assert image is not None + # check that there's no error when saving a pipeline with one of the models being None + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = StableDiffusionPipeline.from_pretrained(tmpdirname) + + # sanity check that the pipeline still works + assert pipe.safety_checker is None + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + def test_stable_diffusion_k_lms(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet