mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[PixArt-Alpha] Introduce resolution binning (#5739)
* feat: add resolution binning Co-authored-by: lawrence-cj <jschen@mail.dlut.edu.cn> * rename * debug * add :test * remove unused variable * set resolution_binning to False. --------- Co-authored-by: lawrence-cj <jschen@mail.dlut.edu.cn>
This commit is contained in:
@@ -19,6 +19,7 @@ import urllib.parse as ul
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
@@ -43,7 +44,6 @@ if is_bs4_available():
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
@@ -60,6 +60,42 @@ EXAMPLE_DOC_STRING = """
|
||||
```
|
||||
"""
|
||||
|
||||
# Lookup table used for resolution binning.
# Each key is an aspect ratio (height / width) written as a string; it is parsed
# with float() when compared against a requested ratio. Each value is the
# [height, width] pair for that bin, stored as floats and cast to int by
# `classify_height_width_bin` before use.
# NOTE(review): the "1024" in the name presumably refers to the base resolution
# of the target checkpoint (the "1.0" bin is 1024 x 1024) -- confirm.
ASPECT_RATIO_1024_BIN = {
|
||||
"0.25": [512.0, 2048.0],
|
||||
"0.28": [512.0, 1856.0],
|
||||
"0.32": [576.0, 1792.0],
|
||||
"0.33": [576.0, 1728.0],
|
||||
"0.35": [576.0, 1664.0],
|
||||
"0.4": [640.0, 1600.0],
|
||||
"0.42": [640.0, 1536.0],
|
||||
"0.48": [704.0, 1472.0],
|
||||
"0.5": [704.0, 1408.0],
|
||||
"0.52": [704.0, 1344.0],
|
||||
"0.57": [768.0, 1344.0],
|
||||
"0.6": [768.0, 1280.0],
|
||||
"0.68": [832.0, 1216.0],
|
||||
"0.72": [832.0, 1152.0],
|
||||
"0.78": [896.0, 1152.0],
|
||||
"0.82": [896.0, 1088.0],
|
||||
"0.88": [960.0, 1088.0],
|
||||
"0.94": [960.0, 1024.0],
|
||||
"1.0": [1024.0, 1024.0],
|
||||
"1.07": [1024.0, 960.0],
|
||||
"1.13": [1088.0, 960.0],
|
||||
"1.21": [1088.0, 896.0],
|
||||
"1.29": [1152.0, 896.0],
|
||||
"1.38": [1152.0, 832.0],
|
||||
"1.46": [1216.0, 832.0],
|
||||
"1.67": [1280.0, 768.0],
|
||||
"1.75": [1344.0, 768.0],
|
||||
"2.0": [1408.0, 704.0],
|
||||
"2.09": [1472.0, 704.0],
|
||||
"2.4": [1536.0, 640.0],
|
||||
"2.5": [1600.0, 640.0],
|
||||
"3.0": [1728.0, 576.0],
|
||||
"4.0": [2048.0, 512.0],
|
||||
}
|
||||
|
||||
|
||||
class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
@@ -495,6 +531,38 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@staticmethod
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
    """Snap a requested size to the bin whose aspect ratio is nearest.

    Args:
        height: Requested image height.
        width: Requested image width.
        ratios: Mapping from an aspect ratio (height / width, given as a string
            or anything `float()` accepts) to a `[height, width]` pair.

    Returns:
        The `(height, width)` of the closest bin, both cast to `int`. When two
        bins are equally close, the one that appears first in `ratios` wins.
    """
    requested_ratio = float(height / width)

    def distance_to_request(bin_key):
        # Keys are stored as strings, so parse them before comparing.
        return abs(float(bin_key) - requested_ratio)

    nearest_key = min(ratios.keys(), key=distance_to_request)
    binned_size = ratios[nearest_key]
    binned_height = int(binned_size[0])
    binned_width = int(binned_size[1])
    return binned_height, binned_width
|
||||
|
||||
@staticmethod
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
    """Resize a batch of images to cover the target size, then center-crop to it.

    The aspect ratio is preserved while resizing: the tensor is scaled by the
    larger of the two per-axis ratios so that both spatial dimensions are at
    least as large as the target, and the excess is cropped symmetrically.

    Args:
        samples: 4-D tensor whose dimensions 2 and 3 are height and width.
        new_width: Target width. Note that width comes BEFORE height here.
        new_height: Target height.

    Returns:
        `samples` itself when it already has the target spatial size, otherwise
        a tensor with spatial size `(new_height, new_width)`.
    """
    orig_height, orig_width = samples.shape[2], samples.shape[3]

    # Nothing to do when the spatial size already matches.
    if orig_height == new_height and orig_width == new_width:
        return samples

    ratio = max(new_height / orig_height, new_width / orig_width)

    # `orig * ratio` is mathematically >= the target on both axes, but float
    # rounding can land just below it (e.g. 49 * (1 / 49) == 0.9999999999999999),
    # and int() would then truncate to one pixel short -- or to zero. Clamp so
    # the resized tensor always covers the crop window.
    resized_height = max(int(orig_height * ratio), new_height)
    resized_width = max(int(orig_width * ratio), new_width)

    # Resize
    samples = F.interpolate(
        samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
    )

    # Center crop
    start_x = (resized_width - new_width) // 2
    start_y = (resized_height - new_height) // 2
    end_x = start_x + new_width
    end_y = start_y + new_height
    return samples[:, :, start_y:end_y, start_x:end_x]
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -518,6 +586,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
callback_steps: int = 1,
|
||||
clean_caption: bool = True,
|
||||
mask_feature: bool = True,
|
||||
use_resolution_binning: bool = True,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -580,6 +649,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
|
||||
use_resolution_binning:
|
||||
(`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the
|
||||
closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images,
|
||||
they are resized back to the requested resolution. Useful for generating non-square images.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -591,6 +664,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
if use_resolution_binning:
|
||||
orig_height, orig_width = height, width
|
||||
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
|
||||
|
||||
self.check_inputs(
|
||||
prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
|
||||
)
|
||||
@@ -709,6 +786,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
if use_resolution_binning:
|
||||
image = self.resize_and_crop_tensor(image, orig_width, orig_height)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
|
||||
@@ -89,7 +89,8 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "numpy",
|
||||
"use_resolution_binning": False,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
@@ -120,6 +121,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"generator": generator,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"output_type": output_type,
|
||||
"use_resolution_binning": False,
|
||||
}
|
||||
|
||||
# set all optional components to None
|
||||
@@ -154,6 +156,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"generator": generator,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"output_type": output_type,
|
||||
"use_resolution_binning": False,
|
||||
}
|
||||
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
@@ -189,8 +192,8 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs, height=32, width=48).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
self.assertEqual(image.shape, (1, 32, 48, 3))
|
||||
|
||||
expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
|
||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||
self.assertLessEqual(max_diff, 1e-3)
|
||||
@@ -219,6 +222,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"output_type": output_type,
|
||||
"num_images_per_prompt": 2,
|
||||
"use_resolution_binning": False,
|
||||
}
|
||||
|
||||
# set all optional components to None
|
||||
@@ -254,6 +258,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"output_type": output_type,
|
||||
"num_images_per_prompt": 2,
|
||||
"use_resolution_binning": False,
|
||||
}
|
||||
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
Reference in New Issue
Block a user