1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix a bug in StableDiffusionUpscalePipeline when prompt is None (#4278)

* fix batch_size

* add test

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
This commit is contained in:
YiYi Xu
2023-07-27 03:07:50 -10:00
committed by Patrick von Platen
parent 49c95178ad
commit a9829164f4
2 changed files with 67 additions and 2 deletions

View File

@@ -424,10 +424,13 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
# verify that the batch sizes of prompt and image match when image is a list, tensor, or numpy array
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
if isinstance(prompt, str):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
else:
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if isinstance(image, list):
image_batch_size = len(image)
else:

View File

@@ -210,6 +210,68 @@ class StableDiffusionUpscalePipelineFastTests(unittest.TestCase):
image = output.images
assert image.shape[0] == 2
def test_stable_diffusion_upscale_prompt_embeds(self):
    """Check that upscaling from ``prompt_embeds`` matches upscaling from a prompt.

    The pipeline is run twice with identically seeded generators: once with
    the prompt given as a one-element list and a single low-res image, and
    once with no prompt at all, only precomputed ``prompt_embeds`` and the
    image wrapped in a list. A 3x3 corner slice of the last channel of each
    result is compared against the same reference values.
    """
    device = "cpu"  # ensure determinism for the device-dependent torch.Generator

    # Gather the pipeline components (same construction order as before).
    unet = self.dummy_cond_unet_upscale
    low_res_scheduler = DDPMScheduler()
    scheduler = DDIMScheduler(prediction_type="v_prediction")
    vae = self.dummy_vae
    text_encoder = self.dummy_text_encoder
    tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

    # Turn the first dummy image into a 64x64 RGB PIL image (channels moved last).
    low_res_array = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
    low_res_image = Image.fromarray(np.uint8(low_res_array)).convert("RGB").resize((64, 64))

    sd_pipe = StableDiffusionUpscalePipeline(
        unet=unet,
        low_res_scheduler=low_res_scheduler,
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        max_noise_level=350,
    )
    sd_pipe = sd_pipe.to(device)
    sd_pipe.set_progress_bar_config(disable=None)

    prompt = "A painting of a squirrel eating a burger"

    # Sampling arguments shared by both pipeline invocations.
    shared_kwargs = {
        "guidance_scale": 6.0,
        "noise_level": 20,
        "num_inference_steps": 2,
        "output_type": "np",
    }

    # Run 1: text prompt supplied as a list, image supplied as a single PIL image.
    generator = torch.Generator(device=device).manual_seed(0)
    image = sd_pipe(
        [prompt],
        image=low_res_image,
        generator=generator,
        **shared_kwargs,
    ).images

    # Run 2: no prompt, only precomputed embeddings; image supplied as a list.
    generator = torch.Generator(device=device).manual_seed(0)
    prompt_embeds = sd_pipe._encode_prompt(prompt, device, 1, False)
    image_from_prompt_embeds = sd_pipe(
        prompt_embeds=prompt_embeds,
        image=[low_res_image],
        generator=generator,
        return_dict=False,
        **shared_kwargs,
    )[0]

    image_slice = image[0, -3:, -3:, -1]
    image_from_prompt_embeds_slice = image_from_prompt_embeds[0, -3:, -3:, -1]

    # The expected output side length is four times the low-res side length.
    expected_height_width = low_res_image.size[0] * 4
    assert image.shape == (1, expected_height_width, expected_height_width, 3)

    expected_slice = np.array([0.3113, 0.3910, 0.4272, 0.4859, 0.5061, 0.4652, 0.5362, 0.5715, 0.5661])
    prompt_error = np.abs(image_slice.flatten() - expected_slice).max()
    embeds_error = np.abs(image_from_prompt_embeds_slice.flatten() - expected_slice).max()
    assert prompt_error < 1e-2
    assert embeds_error < 1e-2
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_stable_diffusion_upscale_fp16(self):
"""Test that stable diffusion upscale works with fp16"""