mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
RePaint fast tests and API conforming (#1701)
* add fast tests * better tests and fp16 * batch fixes * Reuse preprocessing * quickfix
This commit is contained in:
@@ -13,33 +13,61 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...schedulers import RePaintScheduler
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, logging
|
||||
|
||||
|
||||
def _preprocess_image(image: PIL.Image.Image):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
# Module-level logger, created through the package's own `logging` utility and keyed by this module's name.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
    """Turn an image, or a list of images, into a single batched tensor.

    Args:
        image: A `torch.Tensor` (handed back untouched), one `PIL.Image.Image`, or a list whose
            first element decides how the whole list is treated (PIL images or tensors).

    Returns:
        For PIL input, a `torch.float32` tensor with values scaled from `[0, 255]` to `[-1, 1]` and the
        last array axis moved to position 1; every image is resized to the first image's size rounded
        down to a multiple of 32. For a list of tensors, their concatenation along dim 0. Any other
        list is returned unchanged.
    """
    if isinstance(image, torch.Tensor):
        return image

    batch = [image] if isinstance(image, PIL.Image.Image) else image
    first = batch[0]

    if isinstance(first, PIL.Image.Image):
        # resize to integer multiple of 32
        width, height = first.size
        width -= width % 32
        height -= height % 32

        resample = PIL_INTERPOLATION["lanczos"]
        frames = [np.array(img.resize((width, height), resample=resample))[None, :] for img in batch]

        pixels = np.concatenate(frames, axis=0).astype(np.float32) / 255.0
        pixels = pixels.transpose(0, 3, 1, 2)
        pixels = 2.0 * pixels - 1.0
        return torch.from_numpy(pixels)

    if isinstance(first, torch.Tensor):
        return torch.cat(batch, dim=0)

    # Unrecognised element type: pass the list through as the original code path does.
    return batch
|
||||
|
||||
|
||||
def _preprocess_mask(mask: PIL.Image.Image):
|
||||
mask = np.array(mask.convert("L"))
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None, None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
    """Convert a mask, or a list of masks, into a batched binary tensor.

    Args:
        mask: A `torch.Tensor` (returned untouched), one `PIL.Image.Image`, or a list whose first
            element decides how the whole list is treated (PIL images or tensors).

    Returns:
        For PIL input, a `torch.float32` tensor of shape `(batch, 1, height, width)` containing only
        0.0 and 1.0. Every mask is converted to mode "L" and resized (nearest neighbour) to the first
        mask's size rounded down to a multiple of 32; values below 0.5 after scaling to `[0, 1]`
        become 0, the rest become 1. For a list of tensors, their concatenation along dim 0. Any
        other list is returned unchanged.
    """
    if isinstance(mask, torch.Tensor):
        return mask
    if isinstance(mask, PIL.Image.Image):
        mask = [mask]

    if isinstance(mask[0], PIL.Image.Image):
        w, h = mask[0].size
        w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
        # Each mask gets an explicit batch AND channel axis -> (1, 1, h, w). Without the channel
        # axis the stacked result is (batch, h, w), and broadcasting it against a
        # (batch, channels, h, w) image lines the batch axis up with the channel axis, which
        # only works by accident when batch == 1.
        mask = [
            np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, None]
            for m in mask
        ]
        mask = np.concatenate(mask, axis=0)
        mask = mask.astype(np.float32) / 255.0
        # Binarise: everything is forced to exactly 0 or 1.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)
    elif isinstance(mask[0], torch.Tensor):
        mask = torch.cat(mask, dim=0)
    return mask
|
||||
|
||||
|
||||
@@ -54,8 +82,8 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
original_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
image: Union[torch.Tensor, PIL.Image.Image],
|
||||
mask_image: Union[torch.Tensor, PIL.Image.Image],
|
||||
num_inference_steps: int = 250,
|
||||
eta: float = 0.0,
|
||||
jump_length: int = 10,
|
||||
@@ -63,10 +91,11 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
original_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
The original image to inpaint on.
|
||||
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
The mask_image where 0.0 values define which part of the original image to inpaint (change).
|
||||
@@ -97,12 +126,14 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if not isinstance(original_image, torch.FloatTensor):
|
||||
original_image = _preprocess_image(original_image)
|
||||
original_image = original_image.to(self.device)
|
||||
if not isinstance(mask_image, torch.FloatTensor):
|
||||
mask_image = _preprocess_mask(mask_image)
|
||||
mask_image = mask_image.to(self.device)
|
||||
message = "Please use `image` instead of `original_image`."
|
||||
original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs)
|
||||
original_image = original_image or image
|
||||
|
||||
original_image = _preprocess_image(original_image)
|
||||
original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
|
||||
mask_image = _preprocess_mask(mask_image)
|
||||
mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)
|
||||
|
||||
# sample gaussian noise to begin the loop
|
||||
image = torch.randn(
|
||||
@@ -110,14 +141,14 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
image = image.to(self.device)
|
||||
image = image.to(device=self.device, dtype=self.unet.dtype)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
|
||||
self.scheduler.eta = eta
|
||||
|
||||
t_last = self.scheduler.timesteps[0] + 1
|
||||
for i, t in enumerate(tqdm(self.scheduler.timesteps)):
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
if t < t_last:
|
||||
# predict the noise residual
|
||||
model_output = self.unet(image, t).sample
|
||||
|
||||
@@ -270,9 +270,13 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
|
||||
# been observed.
|
||||
|
||||
# 5. Add noise
|
||||
noise = torch.randn(
|
||||
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
|
||||
)
|
||||
device = model_output.device
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
|
||||
noise = noise.to(device)
|
||||
else:
|
||||
noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
|
||||
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
|
||||
|
||||
variance = 0
|
||||
@@ -305,7 +309,12 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
for i in range(n):
|
||||
beta = self.betas[timestep + i]
|
||||
noise = torch.randn(sample.shape, generator=generator, device=sample.device)
|
||||
if sample.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator)
|
||||
noise = noise.to(sample.device)
|
||||
else:
|
||||
noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
|
||||
|
||||
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
|
||||
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
|
||||
|
||||
@@ -21,10 +21,68 @@ import torch
|
||||
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
|
||||
from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
# Turn off TF32 for CUDA matmuls for everything run from this module.
# NOTE(review): presumably so the hard-coded expected values below compare reliably on GPU — confirm.
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast checks for `RePaintPipeline` using a tiny, seeded `UNet2DModel` and 32x32 inputs."""

    pipeline_class = RePaintPipeline
    test_cpu_offload = False

    def get_dummy_components(self):
        """Build the pipeline components: a small seeded UNet and a default `RePaintScheduler`."""
        torch.manual_seed(0)
        return {
            "unet": UNet2DModel(
                block_out_channels=(32, 64),
                layers_per_block=2,
                sample_size=32,
                in_channels=3,
                out_channels=3,
                down_block_types=("DownBlock2D", "AttnDownBlock2D"),
                up_block_types=("AttnUpBlock2D", "UpBlock2D"),
            ),
            "scheduler": RePaintScheduler(),
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return call kwargs for the pipeline: a seeded random image, a mask derived from it, and settings."""
        # On "mps" the global generator is seeded instead of creating a device-specific one.
        on_mps = str(device).startswith("mps")
        generator = torch.manual_seed(seed) if on_mps else torch.Generator(device=device).manual_seed(seed)

        pixels = np.random.RandomState(seed).standard_normal((1, 3, 32, 32))
        image = torch.from_numpy(pixels).to(device=device, dtype=torch.float32)
        # The mask is 1.0 wherever the random image is positive and 0.0 elsewhere.
        mask = (image > 0).to(device=device, dtype=torch.float32)

        return {
            "image": image,
            "mask_image": mask,
            "generator": generator,
            "num_inference_steps": 5,
            "eta": 0.0,
            "jump_length": 2,
            "jump_n_sample": 2,
            "output_type": "numpy",
        }

    def test_repaint(self):
        """Run the pipeline on CPU and compare a corner of the output with stored reference values."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = RePaintPipeline(**self.get_dummy_components())
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        output = pipe(**self.get_dummy_inputs(device)).images
        corner = output[0, -3:, -3:, -1]

        assert output.shape == (1, 32, 32, 3)
        expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
        assert np.abs(corner.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class RepaintPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user