
Fix using non-square images with UNet2DModel and DDIM/DDPM pipelines (#1289)

* fix non square images with UNet2DModel and DDIM/DDPM pipelines

* fix unet_2d `sample_size` docstring

* update pipeline tests for unet uncond

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author:     Penn
Date:       2022-11-23 05:11:39 -05:00
Committer:  GitHub
Parent:     44e56de9aa
Commit:     8fd3a74322

5 changed files with 57 additions and 26 deletions

src/diffusers/models/unet_2d.py

@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     implements for all the model (such as downloading or saving, etc.)

     Parameters:
-        sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
-            Input sample size.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
         out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         in_channels: int = 3,
         out_channels: int = 3,
         center_input_sample: bool = False,
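
For illustration only (not part of the commit): a minimal sketch of the new tuple form, assuming a diffusers version that includes this change. The block/channel arguments mirror the dummy model used in the tests below.

import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=(32, 64),  # (height, width) tuple instead of a single int
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

# The UNet itself is fully convolutional; `sample_size` is config metadata
# that the unconditional pipelines read to pick their default noise shape.
sample = torch.randn(1, 3, 32, 64)
out = model(sample, timestep=0).sample
assert out.shape == sample.shape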

src/diffusers/models/unet_2d_condition.py

@@ -56,7 +56,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
     implements for all the models (such as downloading or saving, etc.)

     Parameters:
-        sample_size (`int`, *optional*): The size of the input sample.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.

src/diffusers/pipelines/ddim/pipeline_ddim.py

@@ -89,7 +89,11 @@ class DDIMPipeline(DiffusionPipeline):
             generator = None

         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
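
Restated outside the pipeline (a standalone sketch; `noise_shape` is a hypothetical helper, not a diffusers function): the new branch simply normalizes an int or (height, width) tuple `sample_size` into a 4-D noise shape.

def noise_shape(batch_size, in_channels, sample_size):
    # int -> square image; (height, width) tuple -> non-square image
    if isinstance(sample_size, int):
        return (batch_size, in_channels, sample_size, sample_size)
    return (batch_size, in_channels, *sample_size)

assert noise_shape(1, 3, 32) == (1, 3, 32, 32)
assert noise_shape(1, 3, (32, 64)) == (1, 3, 32, 64)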

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

@@ -94,7 +94,11 @@ class DDPMPipeline(DiffusionPipeline):
             generator = None

         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
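
End to end, the fix makes the following sketch work (not part of the commit; it assumes a diffusers version with this change and mirrors the new test below):

import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

torch.manual_seed(0)
unet = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=(64, 32),  # non-square (height, width)
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
pipeline = DDPMPipeline(unet, DDPMScheduler())
images = pipeline(num_inference_steps=2, output_type="np").images

# NumPy output is NHWC, so height/width follow `sample_size`.
assert images.shape == (1, 64, 32, 3)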

tests/test_pipelines.py

@@ -18,6 +18,7 @@ import os
 import random
 import tempfile
 import unittest
+from functools import partial

 import numpy as np
 import torch
@@ -46,6 +47,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
+from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -247,7 +249,6 @@ class CustomPipelineTests(unittest.TestCase):


 class PipelineFastTests(unittest.TestCase):
-    @property
     def dummy_image(self):
         batch_size = 1
         num_channels = 3
@@ -256,13 +257,12 @@ class PipelineFastTests(unittest.TestCase):
         image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
         return image

-    @property
-    def dummy_uncond_unet(self):
+    def dummy_uncond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=3,
             out_channels=3,
             down_block_types=("DownBlock2D", "AttnDownBlock2D"),
@@ -270,13 +270,12 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
-    def dummy_cond_unet(self):
+    def dummy_cond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=4,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -285,13 +284,12 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
-    def dummy_cond_unet_inpaint(self):
+    def dummy_cond_unet_inpaint(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=9,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -300,7 +298,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
     def dummy_vq_model(self):
         torch.manual_seed(0)
         model = VQModel(
@@ -313,7 +310,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
     def dummy_vae(self):
         torch.manual_seed(0)
         model = AutoencoderKL(
@@ -326,7 +322,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
     def dummy_text_encoder(self):
         torch.manual_seed(0)
         config = CLIPTextConfig(
@@ -342,7 +337,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return CLIPTextModel(config)

-    @property
     def dummy_extractor(self):
         def extract(*args, **kwargs):
             class Out:
@@ -357,15 +351,43 @@ class PipelineFastTests(unittest.TestCase):
         return extract

-    def test_components(self):
+    @parameterized.expand(
+        [
+            [DDIMScheduler, DDIMPipeline, 32],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32],
+            [DDIMScheduler, DDIMPipeline, (32, 64)],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)],
+        ]
+    )
+    def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
+        unet = self.dummy_uncond_unet(sample_size)
+        # DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
+        scheduler = scheduler_fn()
+        pipeline = pipeline_fn(unet, scheduler).to(torch_device)
+
+        # Device type MPS is not supported for torch.Generator() api.
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        out_image = pipeline(
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        ).images
+        sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
+        assert out_image.shape == (1, *sample_size, 3)
+
+    def test_stable_diffusion_components(self):
         """Test that components property works correctly"""
-        unet = self.dummy_cond_unet
+        unet = self.dummy_cond_unet()
         scheduler = PNDMScheduler(skip_prk_steps=True)
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        vae = self.dummy_vae()
+        bert = self.dummy_text_encoder()
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

-        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
         init_image = Image.fromarray(np.uint8(image)).convert("RGB")
         mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
@@ -377,7 +399,7 @@ class PipelineFastTests(unittest.TestCase):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=self.dummy_extractor(),
         ).to(torch_device)
         img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
         text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
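
As a standalone illustration of the test pattern above (the `Square` and `Rect` classes are hypothetical stand-ins, not diffusers objects): `parameterized.expand` generates one test per entry, and `functools.partial` pre-binds a kwarg that only one of the two callables accepts.

import unittest
from functools import partial

from parameterized import parameterized


class Square:
    def __init__(self, size):
        self.shape = (size, size)


class Rect:
    def __init__(self, size, transpose=False):
        height, width = size
        self.shape = (width, height) if transpose else (height, width)


class ExpandPatternTests(unittest.TestCase):
    @parameterized.expand(
        [
            [Square, 32, (32, 32)],
            # partial pre-binds `transpose`, which Square does not accept
            [partial(Rect, transpose=False), (32, 64), (32, 64)],
        ]
    )
    def test_shapes(self, factory, size, expected):
        self.assertEqual(factory(size).shape, expected)


if __name__ == "__main__":
    unittest.main()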