mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix pipeline __setattr__ value == None (#3063)
* fix pipeline __setattr__ * add test --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -510,7 +510,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
if hasattr(self, name) and hasattr(self.config, name):
|
||||
# We need to overwrite the config if name exists in config
|
||||
if isinstance(getattr(self.config, name), (tuple, list)):
|
||||
if self.config[name][0] is not None:
|
||||
if value is not None and self.config[name][0] is not None:
|
||||
class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
|
||||
else:
|
||||
class_library_tuple = (None, None)
|
||||
|
||||
@@ -929,6 +929,46 @@ class PipelineFastTests(unittest.TestCase):
|
||||
sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
|
||||
assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
|
||||
|
||||
def test_set_component_to_none(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
pipeline = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
prompt = "This is a flower"
|
||||
|
||||
out_image = pipeline(
|
||||
prompt=prompt,
|
||||
generator=generator,
|
||||
num_inference_steps=1,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
pipeline.feature_extractor = None
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
out_image_2 = pipeline(
|
||||
prompt=prompt,
|
||||
generator=generator,
|
||||
num_inference_steps=1,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
assert out_image.shape == (1, 64, 64, 3)
|
||||
assert np.abs(out_image - out_image_2).max() < 1e-3
|
||||
|
||||
def test_set_scheduler_consistency(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
|
||||
Reference in New Issue
Block a user