1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[PixArt-Alpha] Introduce resolution binning (#5739)

* feat: add resolution binning

Co-authored-by: lawrence-cj <jschen@mail.dlut.edu.cn>

* rename

* debug

* add :test

* remove unused variable

* set resolution_binning to False.

---------

Co-authored-by: lawrence-cj <jschen@mail.dlut.edu.cn>
This commit is contained in:
Sayak Paul
2023-11-14 08:34:59 +05:30
committed by GitHub
parent 5b231aa38b
commit ed759f0aee
2 changed files with 87 additions and 3 deletions

View File

@@ -19,6 +19,7 @@ import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import T5EncoderModel, T5Tokenizer
from ...image_processor import VaeImageProcessor
@@ -43,7 +44,6 @@ if is_bs4_available():
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -60,6 +60,42 @@ EXAMPLE_DOC_STRING = """
```
"""
ASPECT_RATIO_1024_BIN = {
"0.25": [512.0, 2048.0],
"0.28": [512.0, 1856.0],
"0.32": [576.0, 1792.0],
"0.33": [576.0, 1728.0],
"0.35": [576.0, 1664.0],
"0.4": [640.0, 1600.0],
"0.42": [640.0, 1536.0],
"0.48": [704.0, 1472.0],
"0.5": [704.0, 1408.0],
"0.52": [704.0, 1344.0],
"0.57": [768.0, 1344.0],
"0.6": [768.0, 1280.0],
"0.68": [832.0, 1216.0],
"0.72": [832.0, 1152.0],
"0.78": [896.0, 1152.0],
"0.82": [896.0, 1088.0],
"0.88": [960.0, 1088.0],
"0.94": [960.0, 1024.0],
"1.0": [1024.0, 1024.0],
"1.07": [1024.0, 960.0],
"1.13": [1088.0, 960.0],
"1.21": [1088.0, 896.0],
"1.29": [1152.0, 896.0],
"1.38": [1152.0, 832.0],
"1.46": [1216.0, 832.0],
"1.67": [1280.0, 768.0],
"1.75": [1344.0, 768.0],
"2.0": [1408.0, 704.0],
"2.09": [1472.0, 704.0],
"2.4": [1536.0, 640.0],
"2.5": [1600.0, 640.0],
"3.0": [1728.0, 576.0],
"4.0": [2048.0, 512.0],
}
class PixArtAlphaPipeline(DiffusionPipeline):
r"""
@@ -495,6 +531,38 @@ class PixArtAlphaPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
@staticmethod
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
"""Returns binned height and width."""
ar = float(height / width)
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
default_hw = ratios[closest_ratio]
return int(default_hw[0]), int(default_hw[1])
@staticmethod
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
orig_height, orig_width = samples.shape[2], samples.shape[3]
# Check if resizing is needed
if orig_height != new_height or orig_width != new_width:
ratio = max(new_height / orig_height, new_width / orig_width)
resized_width = int(orig_width * ratio)
resized_height = int(orig_height * ratio)
# Resize
samples = F.interpolate(
samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
# Center Crop
start_x = (resized_width - new_width) // 2
end_x = start_x + new_width
start_y = (resized_height - new_height) // 2
end_y = start_y + new_height
samples = samples[:, :, start_y:end_y, start_x:end_x]
return samples
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -518,6 +586,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
callback_steps: int = 1,
clean_caption: bool = True,
mask_feature: bool = True,
use_resolution_binning: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -580,6 +649,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
use_resolution_binning:
(`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the
closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images,
they are resized back to the requested resolution. Useful for generating non-square images.
Examples:
@@ -591,6 +664,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
orig_height, orig_width = height, width
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
self.check_inputs(
prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
)
@@ -709,6 +786,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.resize_and_crop_tensor(image, orig_width, orig_height)
else:
image = latents

View File

@@ -89,7 +89,8 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "numpy",
"use_resolution_binning": False,
"output_type": "np",
}
return inputs
@@ -120,6 +121,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
# set all optional components to None
@@ -154,6 +156,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
output_loaded = pipe_loaded(**inputs)[0]
@@ -189,8 +192,8 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs, height=32, width=48).images
image_slice = image[0, -3:, -3:, -1]
self.assertEqual(image.shape, (1, 32, 48, 3))
expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
@@ -219,6 +222,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"num_images_per_prompt": 2,
"use_resolution_binning": False,
}
# set all optional components to None
@@ -254,6 +258,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"num_images_per_prompt": 2,
"use_resolution_binning": False,
}
output_loaded = pipe_loaded(**inputs)[0]