mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[img2img, inpainting] fix fp16 inference (#769)
* handle dtype in vae and image2image pipeline * fix inpaint in fp16 * dtype should be handled in add_noise * style * address review comments * add simple fast tests to check fp16 * fix test name * put mask in fp16
This commit is contained in:
@@ -337,12 +337,16 @@ class DiagonalGaussianDistribution(object):
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
||||
)
|
||||
|
||||
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
||||
device = self.parameters.device
|
||||
sample_device = "cpu" if device.type == "mps" else device
|
||||
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
|
||||
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
|
||||
# make sure sample is on the same device as the parameters and has same dtype
|
||||
sample = sample.to(device=device, dtype=self.parameters.dtype)
|
||||
x = self.mean + self.std * sample
|
||||
return x
|
||||
|
||||
|
||||
@@ -217,26 +217,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image)
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
# expand init_latents for batch_size
|
||||
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
@@ -297,6 +277,28 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
# expand init_latents for batch_size
|
||||
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
@@ -341,7 +343,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
@@ -234,43 +234,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# preprocess image
|
||||
if not isinstance(init_image, torch.FloatTensor):
|
||||
init_image = preprocess_image(init_image)
|
||||
init_image = init_image.to(self.device)
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
# Expand init_latents for batch_size and num_images_per_prompt
|
||||
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
|
||||
init_latents_orig = init_latents
|
||||
|
||||
# preprocess mask
|
||||
if not isinstance(mask_image, torch.FloatTensor):
|
||||
mask_image = preprocess_mask(mask_image)
|
||||
mask_image = mask_image.to(self.device)
|
||||
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
|
||||
|
||||
# check sizes
|
||||
if not mask.shape == init_latents.shape:
|
||||
raise ValueError("The mask and init_image should be the same size!")
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
@@ -335,6 +298,43 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# preprocess image
|
||||
if not isinstance(init_image, torch.FloatTensor):
|
||||
init_image = preprocess_image(init_image)
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
# Expand init_latents for batch_size and num_images_per_prompt
|
||||
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
|
||||
init_latents_orig = init_latents
|
||||
|
||||
# preprocess mask
|
||||
if not isinstance(mask_image, torch.FloatTensor):
|
||||
mask_image = preprocess_mask(mask_image)
|
||||
mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
|
||||
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
|
||||
|
||||
# check sizes
|
||||
if not mask.shape == init_latents.shape:
|
||||
raise ValueError("The mask and init_image should be the same size!")
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
|
||||
@@ -1005,6 +1005,124 @@ class PipelineFastTests(unittest.TestCase):
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
|
||||
def test_stable_diffusion_fp16(self):
|
||||
"""Test that stable diffusion works with fp16"""
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
vae = vae.half()
|
||||
bert = bert.half()
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=self.dummy_safety_checker,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
|
||||
def test_stable_diffusion_img2img_fp16(self):
|
||||
"""Test that stable diffusion img2img works with fp16"""
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
init_image = self.dummy_image.to(torch_device)
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
vae = vae.half()
|
||||
bert = bert.half()
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionImg2ImgPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=self.dummy_safety_checker,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
init_image=init_image,
|
||||
).images
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
|
||||
def test_stable_diffusion_inpaint_fp16(self):
|
||||
"""Test that stable diffusion inpaint works with fp16"""
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
vae = vae.half()
|
||||
bert = bert.half()
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=self.dummy_safety_checker,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
).images
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
|
||||
Reference in New Issue
Block a user