mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix a bug in StableDiffusionUpscalePipeline when prompt is None (#4278)
* fix batch_size * add test --------- Co-authored-by: yiyixuxu <yixu310@gmail.com>
This commit is contained in:
committed by
Patrick von Platen
parent
49c95178ad
commit
a9829164f4
@@ -424,10 +424,13 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
|
||||
if isinstance(prompt, str):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
else:
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
|
||||
@@ -210,6 +210,68 @@ class StableDiffusionUpscalePipelineFastTests(unittest.TestCase):
|
||||
image = output.images
|
||||
assert image.shape[0] == 2
|
||||
|
||||
def test_stable_diffusion_upscale_prompt_embeds(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet_upscale
|
||||
low_res_scheduler = DDPMScheduler()
|
||||
scheduler = DDIMScheduler(prediction_type="v_prediction")
|
||||
vae = self.dummy_vae
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionUpscalePipeline(
|
||||
unet=unet,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
max_noise_level=350,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
image=low_res_image,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
noise_level=20,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
prompt_embeds = sd_pipe._encode_prompt(prompt, device, 1, False)
|
||||
image_from_prompt_embeds = sd_pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
image=[low_res_image],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
noise_level=20,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_prompt_embeds_slice = image_from_prompt_embeds[0, -3:, -3:, -1]
|
||||
|
||||
expected_height_width = low_res_image.size[0] * 4
|
||||
assert image.shape == (1, expected_height_width, expected_height_width, 3)
|
||||
expected_slice = np.array([0.3113, 0.3910, 0.4272, 0.4859, 0.5061, 0.4652, 0.5362, 0.5715, 0.5661])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_prompt_embeds_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_upscale_fp16(self):
|
||||
"""Test that stable diffusion upscale works with fp16"""
|
||||
|
||||
Reference in New Issue
Block a user