From 639f6455b4e1aca0d2bdc858c359dff0499d43bf Mon Sep 17 00:00:00 2001 From: Will Berman Date: Wed, 12 Apr 2023 04:11:09 -0700 Subject: [PATCH] fix pipeline __setattr__ value == None (#3063) * fix pipeline __setattr__ * add test --------- Co-authored-by: Patrick von Platen --- src/diffusers/pipelines/pipeline_utils.py | 2 +- tests/test_pipelines.py | 40 +++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 06912a1464..2e20c21aaf 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -510,7 +510,7 @@ class DiffusionPipeline(ConfigMixin): if hasattr(self, name) and hasattr(self.config, name): # We need to overwrite the config if name exists in config if isinstance(getattr(self.config, name), (tuple, list)): - if self.config[name][0] is not None: + if value is not None and self.config[name][0] is not None: class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__) else: class_library_tuple = (None, None) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 048030d983..a5d70b01d4 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -929,6 +929,46 @@ class PipelineFastTests(unittest.TestCase): sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config) assert isinstance(sd.scheduler, DPMSolverMultistepScheduler) + def test_set_component_to_none(self): + unet = self.dummy_cond_unet() + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + pipeline = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + + generator = torch.Generator(device="cpu").manual_seed(0) + + prompt = "This is a flower" + + out_image = pipeline( + prompt=prompt, + generator=generator, + num_inference_steps=1, + output_type="np", + ).images + + pipeline.feature_extractor = None + generator = torch.Generator(device="cpu").manual_seed(0) + out_image_2 = pipeline( + prompt=prompt, + generator=generator, + num_inference_steps=1, + output_type="np", + ).images + + assert out_image.shape == (1, 64, 64, 3) + assert np.abs(out_image - out_image_2).max() < 1e-3 + def test_set_scheduler_consistency(self): unet = self.dummy_cond_unet() pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")