
Allow resolutions that are not multiples of 64 (#505)

* Allow resolutions that are not multiples of 64

* ran black

* fix bug

* add test

* more explanation

* more comments

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Josh Achiam
Date: 2022-09-30 00:54:40 -07:00
Committed by: GitHub
Parent: 9ebaea545f
Commit: a784be2ebe

5 changed files with 106 additions and 12 deletions
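
In practice, this change means resolutions only need to be divisible by the VAE factor of 8 rather than by 64: the UNet now forces the interpolation output size on its upsampling path whenever the sample is not a multiple of its overall upsampling factor. A minimal usage sketch (the checkpoint id, dtype, and device are illustrative and not part of this commit):

```python
import torch
from diffusers import StableDiffusionPipeline

# 536 = 8 * 67: divisible by the VAE factor of 8, but not by 64,
# so it only works once the UNet forwards explicit upsample sizes.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "A painting of a squirrel eating a burger",
    height=536,
    width=536,
    num_inference_steps=50,
).images[0]
image.save("squirrel_536.png")
```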


@@ -193,7 +193,9 @@ if os.name == "nt":  # windows
 else:
     extras["flax"] = deps_list("jax", "jaxlib", "flax")
 
-extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+extras["dev"] = (
+    extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+)
 
 install_requires = [
     deps["importlib_metadata"],


@@ -34,12 +34,18 @@ class Upsample2D(nn.Module):
         else:
             self.Conv2d_0 = conv
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, output_size=None):
         assert hidden_states.shape[1] == self.channels
+
         if self.use_conv_transpose:
             return self.conv(hidden_states)
 
-        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        # if `output_size` is passed we force the interpolation output
+        # size and do not make use of `scale_factor=2`
+        if output_size is None:
+            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        else:
+            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
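
To see why `output_size` matters: once a spatial dimension becomes odd on the way down, `scale_factor=2` overshoots the matching skip connection by one pixel. A standalone sketch of the mismatch this hunk fixes (channel counts are illustrative; a 536-pixel image gives a 67-wide latent, which three stride-2 convolutions reduce to 34, 17, and 9):

```python
import torch
import torch.nn.functional as F

# Innermost feature map for a 67x67 latent after three stride-2 convolutions: 9x9.
x = torch.randn(1, 64, 9, 9)

# Plain doubling gives 18x18, but the skip connection from the encoder is 17x17,
# so concatenating along the channel dim would fail with a size mismatch.
print(F.interpolate(x, scale_factor=2.0, mode="nearest").shape)  # [1, 64, 18, 18]

# Forcing the output size (what `output_size` does in Upsample2D) restores alignment.
print(F.interpolate(x, size=(17, 17), mode="nearest").shape)     # [1, 64, 17, 17]
```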


@@ -7,7 +7,7 @@ import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import (
     CrossAttnDownBlock2D,
@@ -20,6 +20,9 @@ from .unet_blocks import (
 )
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """
@@ -145,15 +148,25 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             resnet_groups=norm_num_groups,
         )
 
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
-            is_final_block = i == len(block_out_channels) - 1
-
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
             input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
 
+            is_final_block = i == len(block_out_channels) - 1
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
             up_block = get_up_block(
                 up_block_type,
@@ -162,7 +175,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 temb_channels=time_embed_dim,
-                add_upsample=not is_final_block,
+                add_upsample=add_upsample,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
@@ -223,6 +236,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
         """
+        # By default samples have to be AT least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
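
A quick numeric illustration of the check above, assuming the stock Stable Diffusion UNet, where three of the four up blocks upsample (`num_upsamplers = 3`; these values are not part of this diff):

```python
# Hypothetical values for the stock Stable Diffusion UNet.
num_upsamplers = 3
default_overall_up_factor = 2**num_upsamplers  # 8

# 512x512 image -> 64x64 latent: divisible by 8, the default scale_factor=2 path is kept.
print(64 % default_overall_up_factor)  # 0 -> forward_upsample_size stays False

# 536x536 image -> 67x67 latent: not divisible by 8, so explicit upsample sizes are forwarded.
print(67 % default_overall_up_factor)  # 3 -> forward_upsample_size becomes True
```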
@@ -262,20 +289,29 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
 
         # 5. up
-        for upsample_block in self.up_blocks:
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
 
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
             if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
                 )
             else:
-                sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
 
         # 6. post-process
         # make sure hidden states is in float32
         # when running in half-precision


@@ -1126,6 +1126,7 @@ class CrossAttnUpBlock2D(nn.Module):
         res_hidden_states_tuple,
         temb=None,
         encoder_hidden_states=None,
+        upsample_size=None,
     ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
@@ -1151,7 +1152,7 @@ class CrossAttnUpBlock2D(nn.Module):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
 
@@ -1204,7 +1205,7 @@ class UpBlock2D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -1225,7 +1226,7 @@ class UpBlock2D(nn.Module):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
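
With that plumbing in place, the up blocks simply hand `upsample_size` to each of their `Upsample2D` modules. A standalone sketch at the module level (the import path matches the library layout at the time of this commit; the channel count and shapes are arbitrary):

```python
import torch
from diffusers.models.resnet import Upsample2D

up = Upsample2D(channels=64, use_conv=True)
x = torch.randn(1, 64, 9, 9)

print(up(x).shape)            # torch.Size([1, 64, 18, 18]) -- default scale_factor=2 path
print(up(x, (17, 17)).shape)  # torch.Size([1, 64, 17, 17]) -- forced interpolation output size
```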


@@ -336,6 +336,55 @@ class PipelineFastTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_ddim_factor_8(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            height=536,
+            width=536,
+            num_inference_steps=2,
+            output_type="np",
+        )
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 134, 134, 3)
+        expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
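
The resolution in the new `test_stable_diffusion_ddim_factor_8` test above is chosen so that the old multiple-of-64 assumption would fail while the factor-of-8 requirement still holds (the asserted 134x134 output shape reflects the tiny dummy models these fast tests use rather than the full-size pipeline). A one-line check of why 536 exercises the new code path:

```python
height = 536
assert height % 8 == 0    # valid latent size: 536 // 8 = 67
assert height % 64 != 0   # not a multiple of 64 -> the UNet must forward upsample sizes
```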