mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge branch 'main' of https://github.com/huggingface/diffusers into main
This commit is contained in:
@@ -536,7 +536,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
batch_size * num_images_per_prompt,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
|
||||
@@ -215,6 +215,47 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
|
||||
device = "cpu"
|
||||
unet = self.dummy_cond_unet_inpaint
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
images = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
num_images_per_prompt=2,
|
||||
).images
|
||||
|
||||
# check if the output is a list of 2 images
|
||||
assert len(images) == 2
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_inpaint_fp16(self):
|
||||
"""Test that stable diffusion inpaint_legacy works with fp16"""
|
||||
|
||||
Reference in New Issue
Block a user