1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Update pipeline_stable_diffusion_inpaint_legacy.py (#2903)

* Update pipeline_stable_diffusion_inpaint_legacy.py

* fix preprocessing of Pil images with adequate batch size

* revert map

* add tests

* reformat

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* next try to fix the style

* wth is this

* Update testing_utils.py

* Update testing_utils.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
hwuebben
2023-04-19 18:58:13 +02:00
committed by GitHub
parent c8fdfe4572
commit 3becd368b1
3 changed files with 108 additions and 15 deletions

View File

@@ -41,17 +41,17 @@ from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__)
def preprocess_image(image):
def preprocess_image(image, batch_size):
w, h = image.size
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def preprocess_mask(mask, scale_factor=8):
def preprocess_mask(mask, batch_size, scale_factor=8):
if not isinstance(mask, torch.FloatTensor):
mask = mask.convert("L")
w, h = mask.size
@@ -59,7 +59,7 @@ def preprocess_mask(mask, scale_factor=8):
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
mask = np.vstack([mask[None]] * batch_size)
mask = 1 - mask # repaint white, keep black
mask = torch.from_numpy(mask)
return mask
@@ -521,14 +521,14 @@ class StableDiffusionInpaintPipelineLegacy(
return timesteps, num_inference_steps - t_start
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
image = image.to(device=self.device, dtype=dtype)
init_latent_dist = self.vae.encode(image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = self.vae.config.scaling_factor * init_latents
# Expand init_latents for batch_size and num_images_per_prompt
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
init_latents_orig = init_latents
# add noise to latents using the timesteps
@@ -659,9 +659,9 @@ class StableDiffusionInpaintPipelineLegacy(
# 4. Preprocess image and mask
if not isinstance(image, torch.FloatTensor):
image = preprocess_image(image)
image = preprocess_image(image, batch_size)
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -671,12 +671,12 @@ class StableDiffusionInpaintPipelineLegacy(
# 6. Prepare latent variables
# encode the init image into latents and scale the latents
latents, init_latents_orig, noise = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
# 7. Prepare mask latent
mask = mask_image.to(device=self.device, dtype=latents.dtype)
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
mask = torch.cat([mask] * num_images_per_prompt)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

View File

@@ -279,6 +279,16 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
return image
def preprocess_image(image: PIL.Image, batch_size: int):
w, h = image.size
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
if is_opencv_available():
import cv2

View File

@@ -34,7 +34,7 @@ from diffusers import (
VQModel,
)
from diffusers.utils import floats_tensor, load_image, nightly, slow, torch_device
from diffusers.utils.testing_utils import load_numpy, require_torch_gpu
from diffusers.utils.testing_utils import load_numpy, preprocess_image, require_torch_gpu
torch.backends.cuda.matmul.allow_tf32 = False
@@ -217,6 +217,55 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint_legacy_batched(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
init_images_tens = preprocess_image(init_image, batch_size=2)
init_masks_tens = init_images_tens + 4
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipelineLegacy(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
images = sd_pipe(
[prompt] * 2,
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
image=init_images_tens,
mask_image=init_masks_tens,
).images
assert images.shape == (2, 32, 32, 3)
image_slice_0 = images[0, -3:, -3:, -1].flatten()
image_slice_1 = images[1, -3:, -3:, -1].flatten()
expected_slice_0 = np.array([0.4697, 0.3770, 0.4096, 0.4653, 0.4497, 0.4183, 0.3950, 0.4668, 0.4672])
expected_slice_1 = np.array([0.4105, 0.4987, 0.5771, 0.4921, 0.4237, 0.5684, 0.5496, 0.4645, 0.5272])
assert np.abs(expected_slice_0 - image_slice_0).max() < 1e-2
assert np.abs(expected_slice_1 - image_slice_1).max() < 1e-2
def test_stable_diffusion_inpaint_legacy_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
@@ -349,7 +398,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
def get_inputs(self, generator_device="cpu", seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
@@ -379,7 +428,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device)
inputs = self.get_inputs()
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
@@ -388,6 +437,40 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
assert np.abs(expected_slice - image_slice).max() < 1e-4
def test_stable_diffusion_inpaint_legacy_batched(self):
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs()
inputs["prompt"] = [inputs["prompt"]] * 2
inputs["image"] = preprocess_image(inputs["image"], batch_size=2)
mask = inputs["mask_image"].convert("L")
mask = np.array(mask).astype(np.float32) / 255.0
mask = torch.from_numpy(1 - mask)
masks = torch.vstack([mask[None][None]] * 2)
inputs["mask_image"] = masks
image = pipe(**inputs).images
assert image.shape == (2, 512, 512, 3)
image_slice_0 = image[0, 253:256, 253:256, -1].flatten()
image_slice_1 = image[1, 253:256, 253:256, -1].flatten()
expected_slice_0 = np.array(
[0.52093095, 0.4176447, 0.32752383, 0.6175223, 0.50563973, 0.36470804, 0.65460044, 0.5775188, 0.44332123]
)
expected_slice_1 = np.array(
[0.3592432, 0.4233033, 0.3914635, 0.31014425, 0.3702293, 0.39412856, 0.17526966, 0.2642669, 0.37480092]
)
assert np.abs(expected_slice_0 - image_slice_0).max() < 1e-4
assert np.abs(expected_slice_1 - image_slice_1).max() < 1e-4
def test_stable_diffusion_inpaint_legacy_k_lms(self):
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None
@@ -397,7 +480,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device)
inputs = self.get_inputs()
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
@@ -437,7 +520,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device, dtype=torch.float16)
inputs = self.get_inputs()
pipe(**inputs, callback=callback_fn, callback_steps=1)
assert callback_fn.has_been_called
assert number_of_steps == 2