mirror of https://github.com/huggingface/diffusers.git
Fix using non-square images with UNet2DModel and DDIM/DDPM pipelines (#1289)
* fix non-square images with UNet2DModel and DDIM/DDPM pipelines
* fix unet_2d `sample_size` docstring
* update pipeline tests for unet uncond

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
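As an illustration of what this change enables, here is a minimal sketch of building an unconditional pipeline with a non-square `sample_size`. The UNet hyperparameters mirror the dummy test model in the diff below; `up_block_types` is an assumption, since the diff elides it.

import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

torch.manual_seed(0)
# A tiny UNet accepting a (height, width) tuple instead of a single int.
unet = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=(32, 64),  # height=32, width=64 -- previously only an int worked
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),  # assumed; elided in the diff
)
pipe = DDPMPipeline(unet=unet, scheduler=DDPMScheduler())
# Output is NHWC numpy: (batch, height, width, channels).
images = pipe(num_inference_steps=2, output_type="np").images
assert images.shape == (1, 32, 64, 3)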
@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     implements for all the model (such as downloading or saving, etc.)

     Parameters:
-        sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
-            Input sample size.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
         out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         in_channels: int = 3,
         out_channels: int = 3,
         center_input_sample: bool = False,
@@ -56,7 +56,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
     implements for all the models (such as downloading or saving, etc.)

     Parameters:
-        sample_size (`int`, *optional*): The size of the input sample.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -89,7 +89,11 @@ class DDIMPipeline(DiffusionPipeline):
             generator = None

         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)

         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
@@ -94,7 +94,11 @@ class DDPMPipeline(DiffusionPipeline):
             generator = None

         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)

         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
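For reference, the new branch in both pipelines reduces to the following shape computation. This is a standalone sketch with illustrative values, not code from the diff.

# Standalone illustration of the image_shape logic above, with made-up values.
batch_size, in_channels = 2, 3
for sample_size in (32, (32, 64)):
    if isinstance(sample_size, int):
        image_shape = (batch_size, in_channels, sample_size, sample_size)
    else:
        image_shape = (batch_size, in_channels, *sample_size)  # (height, width)
    print(image_shape)  # (2, 3, 32, 32), then (2, 3, 32, 64)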
@@ -18,6 +18,7 @@ import os
 import random
 import tempfile
 import unittest
+from functools import partial

 import numpy as np
 import torch
@@ -46,6 +47,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
+from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

@@ -247,7 +249,6 @@ class CustomPipelineTests(unittest.TestCase):


 class PipelineFastTests(unittest.TestCase):
-    @property
     def dummy_image(self):
         batch_size = 1
         num_channels = 3
@@ -256,13 +257,12 @@ class PipelineFastTests(unittest.TestCase):
         image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
         return image

-    @property
-    def dummy_uncond_unet(self):
+    def dummy_uncond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=3,
             out_channels=3,
             down_block_types=("DownBlock2D", "AttnDownBlock2D"),
@@ -270,13 +270,12 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
-    def dummy_cond_unet(self):
+    def dummy_cond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=4,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -285,13 +284,12 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
-    def dummy_cond_unet_inpaint(self):
+    def dummy_cond_unet_inpaint(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=9,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -300,7 +298,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
     def dummy_vq_model(self):
         torch.manual_seed(0)
         model = VQModel(
@@ -313,7 +310,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
     def dummy_vae(self):
         torch.manual_seed(0)
         model = AutoencoderKL(
@@ -326,7 +322,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
     def dummy_text_encoder(self):
         torch.manual_seed(0)
         config = CLIPTextConfig(
@@ -342,7 +337,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return CLIPTextModel(config)

-    @property
     def dummy_extractor(self):
         def extract(*args, **kwargs):
             class Out:
@@ -357,15 +351,43 @@ class PipelineFastTests(unittest.TestCase):

         return extract

-    def test_components(self):
+    @parameterized.expand(
+        [
+            [DDIMScheduler, DDIMPipeline, 32],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32],
+            [DDIMScheduler, DDIMPipeline, (32, 64)],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)],
+        ]
+    )
+    def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
+        unet = self.dummy_uncond_unet(sample_size)
+        # DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
+        scheduler = scheduler_fn()
+        pipeline = pipeline_fn(unet, scheduler).to(torch_device)
+
+        # Device type MPS is not supported for torch.Generator() api.
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        out_image = pipeline(
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        ).images
+        sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
+        assert out_image.shape == (1, *sample_size, 3)
+
+    def test_stable_diffusion_components(self):
         """Test that components property works correctly"""
-        unet = self.dummy_cond_unet
+        unet = self.dummy_cond_unet()
         scheduler = PNDMScheduler(skip_prk_steps=True)
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        vae = self.dummy_vae()
+        bert = self.dummy_text_encoder()
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

-        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
         init_image = Image.fromarray(np.uint8(image)).convert("RGB")
         mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))

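The `functools.partial` in the decorator above exists because `DDPMScheduler` requires `predict_epsilon` while `DDIMScheduler` does not accept it, so every table entry must be a zero-argument callable. A standalone sketch of the pattern follows; the stand-in functions are illustrative, not diffusers API.

from functools import partial

def ddim_scheduler():  # stand-in for DDIMScheduler
    return "ddim"

def ddpm_scheduler(predict_epsilon):  # stand-in for DDPMScheduler
    return f"ddpm(predict_epsilon={predict_epsilon})"

# Both entries are callable with no arguments, as the test body expects.
for scheduler_fn in (ddim_scheduler, partial(ddpm_scheduler, predict_epsilon=True)):
    print(scheduler_fn())  # "ddim", then "ddpm(predict_epsilon=True)"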
@@ -377,7 +399,7 @@ class PipelineFastTests(unittest.TestCase):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=self.dummy_extractor(),
         ).to(torch_device)
         img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
         text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)