1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix pipeline __setattr__ value == None (#3063)

* fix pipeline __setattr__

* add test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Will Berman
2023-04-12 04:11:09 -07:00
committed by GitHub
parent 9d7c08f95e
commit 639f6455b4
2 changed files with 41 additions and 1 deletions

View File

@@ -510,7 +510,7 @@ class DiffusionPipeline(ConfigMixin):
if hasattr(self, name) and hasattr(self.config, name):
# We need to overwrite the config if name exists in config
if isinstance(getattr(self.config, name), (tuple, list)):
if self.config[name][0] is not None:
if value is not None and self.config[name][0] is not None:
class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
else:
class_library_tuple = (None, None)

View File

@@ -929,6 +929,46 @@ class PipelineFastTests(unittest.TestCase):
sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
def test_set_component_to_none(self):
unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
pipeline = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "This is a flower"
out_image = pipeline(
prompt=prompt,
generator=generator,
num_inference_steps=1,
output_type="np",
).images
pipeline.feature_extractor = None
generator = torch.Generator(device="cpu").manual_seed(0)
out_image_2 = pipeline(
prompt=prompt,
generator=generator,
num_inference_steps=1,
output_type="np",
).images
assert out_image.shape == (1, 64, 64, 3)
assert np.abs(out_image - out_image_2).max() < 1e-3
def test_set_scheduler_consistency(self):
unet = self.dummy_cond_unet()
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")