From f55873b783e2739124360d869955f9218817c211 Mon Sep 17 00:00:00 2001
From: Junsong Chen
Date: Sun, 3 Mar 2024 13:01:21 +0800
Subject: [PATCH] Fix PixArt 256px inference (#6789)

* feat 256px diffusers inference bug

* change the max_length of T5 to pipeline config file

* fix bug in convert_pixart_alpha_to_diffusers.py

* Update scripts/convert_pixart_alpha_to_diffusers.py

Co-authored-by: Sayak Paul

* remove multi_scale_train parser

* Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Co-authored-by: YiYi Xu

* Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Co-authored-by: YiYi Xu

* styling

* change `model_token_max_length` to call argument.

* Refactoring

* add: max_sequence_length to the docstring.

---------

Co-authored-by: Sayak Paul
Co-authored-by: YiYi Xu
---
 scripts/convert_pixart_alpha_to_diffusers.py  |  8 +--
 .../models/transformers/transformer_2d.py     |  6 +-
 .../pixart_alpha/pipeline_pixart_alpha.py     | 56 +++++++++++++++++--
 3 files changed, 59 insertions(+), 11 deletions(-)

diff --git a/scripts/convert_pixart_alpha_to_diffusers.py b/scripts/convert_pixart_alpha_to_diffusers.py
index fc037c87f5..228b479df0 100644
--- a/scripts/convert_pixart_alpha_to_diffusers.py
+++ b/scripts/convert_pixart_alpha_to_diffusers.py
@@ -9,11 +9,11 @@ from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPip

 ckpt_id = "PixArt-alpha/PixArt-alpha"
 # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
-interpolation_scale = {512: 1, 1024: 2}
+interpolation_scale = {256: 0.5, 512: 1, 1024: 2}


 def main(args):
-    all_state_dict = torch.load(args.orig_ckpt_path)
+    all_state_dict = torch.load(args.orig_ckpt_path, map_location="cpu")
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}

@@ -22,7 +22,6 @@ def main(args):
     converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

     # Caption projection.
-    converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
     converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
     converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
     converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
@@ -155,6 +154,7 @@ def main(args):

     assert transformer.pos_embed.pos_embed is not None
     state_dict.pop("pos_embed")
+    state_dict.pop("y_embedder.y_embedding")
     assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

     num_model_params = sum(p.numel() for p in transformer.parameters())
@@ -187,7 +187,7 @@ if __name__ == "__main__":
         "--image_size",
         default=1024,
         type=int,
-        choices=[512, 1024],
+        choices=[256, 512, 1024],
         required=False,
         help="Image size of pretrained model, either 512 or 1024.",
     )
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index aaf5f2a901..bd632660f4 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -97,6 +97,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         attention_type: str = "default",
         caption_channels: int = None,
+        interpolation_scale: float = None,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -168,8 +169,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             self.width = sample_size

             self.patch_size = patch_size
-            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
-            interpolation_scale = max(interpolation_scale, 1)
+            interpolation_scale = (
+                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
+            )
             self.pos_embed = PatchEmbed(
                 height=sample_size,
                 width=sample_size,
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index b4453e63d2..c12bca90aa 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -133,6 +133,42 @@ ASPECT_RATIO_512_BIN = {
     "4.0": [1024.0, 256.0],
 }

+ASPECT_RATIO_256_BIN = {
+    "0.25": [128.0, 512.0],
+    "0.28": [128.0, 464.0],
+    "0.32": [144.0, 448.0],
+    "0.33": [144.0, 432.0],
+    "0.35": [144.0, 416.0],
+    "0.4": [160.0, 400.0],
+    "0.42": [160.0, 384.0],
+    "0.48": [176.0, 368.0],
+    "0.5": [176.0, 352.0],
+    "0.52": [176.0, 336.0],
+    "0.57": [192.0, 336.0],
+    "0.6": [192.0, 320.0],
+    "0.68": [208.0, 304.0],
+    "0.72": [208.0, 288.0],
+    "0.78": [224.0, 288.0],
+    "0.82": [224.0, 272.0],
+    "0.88": [240.0, 272.0],
+    "0.94": [240.0, 256.0],
+    "1.0": [256.0, 256.0],
+    "1.07": [256.0, 240.0],
+    "1.13": [272.0, 240.0],
+    "1.21": [272.0, 224.0],
+    "1.29": [288.0, 224.0],
+    "1.38": [288.0, 208.0],
+    "1.46": [304.0, 208.0],
+    "1.67": [320.0, 192.0],
+    "1.75": [336.0, 192.0],
+    "2.0": [352.0, 176.0],
+    "2.09": [368.0, 176.0],
+    "2.4": [384.0, 160.0],
+    "2.5": [400.0, 160.0],
+    "3.0": [432.0, 144.0],
+    "4.0": [512.0, 128.0],
+}
+

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
@@ -260,6 +296,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         prompt_attention_mask: Optional[torch.FloatTensor] = None,
         negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         clean_caption: bool = False,
+        max_sequence_length: int = 120,
         **kwargs,
     ):
         r"""
@@ -284,8 +321,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
                 string.
-            clean_caption (bool, defaults to `False`):
+            clean_caption (`bool`, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
+            max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
         """

         if "mask_feature" in kwargs:
@@ -303,7 +341,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         # See Section 3.1. of the paper.
-        max_length = 120
+        max_length = max_sequence_length

         if prompt_embeds is None:
             prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
@@ -688,6 +726,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         callback_steps: int = 1,
         clean_caption: bool = True,
         use_resolution_binning: bool = True,
+        max_sequence_length: int = 120,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
@@ -757,6 +796,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
                 `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
                 the requested resolution. Useful for generating non-square images.
+            max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.

         Examples:

@@ -772,9 +812,14 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
         if use_resolution_binning:
-            aspect_ratio_bin = (
-                ASPECT_RATIO_1024_BIN if self.transformer.config.sample_size == 128 else ASPECT_RATIO_512_BIN
-            )
+            if self.transformer.config.sample_size == 128:
+                aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+            elif self.transformer.config.sample_size == 64:
+                aspect_ratio_bin = ASPECT_RATIO_512_BIN
+            elif self.transformer.config.sample_size == 32:
+                aspect_ratio_bin = ASPECT_RATIO_256_BIN
+            else:
+                raise ValueError("Invalid sample size")
             orig_height, orig_width = height, width
             height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

@@ -822,6 +867,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             clean_caption=clean_caption,
+            max_sequence_length=max_sequence_length,
         )
         if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
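
A note on the transformer change above: the old fallback max(self.config.sample_size // 64, 1) floors to 0 and is then clamped to 1 for the 256px model (latent sample_size 32 = 256 / 8), while the PixArt-alpha reference code uses 0.5 at that resolution, which is why Transformer2DModel now accepts an explicit interpolation_scale and the converter's table gains a {256: 0.5} entry. A minimal sketch of that arithmetic (the helper name is hypothetical, not diffusers API):

    from typing import Optional


    def resolve_interpolation_scale(sample_size: int, interpolation_scale: Optional[float] = None) -> float:
        # Same branch as the Transformer2DModel.__init__ change: an explicit value wins,
        # otherwise fall back to the old sample_size-based heuristic.
        return interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)


    assert resolve_interpolation_scale(128) == 2        # 1024px model (sample_size = 1024 / 8)
    assert resolve_interpolation_scale(64) == 1         # 512px model
    assert resolve_interpolation_scale(32) == 1         # 256px model: fallback floors to 0, clamps to 1
    assert resolve_interpolation_scale(32, 0.5) == 0.5  # so the 256px config needs the explicit 0.5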
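
And a sketch of exercising the 256px path end to end once the patch is applied. The checkpoint path, dump directory, and prompt are placeholders, and the converter invocation assumes the script's existing --orig_ckpt_path/--dump_path arguments; --image_size 256, use_resolution_binning, and max_sequence_length are the options the patch adds or touches:

    # Convert an original 256px PixArt-alpha checkpoint to diffusers format (paths are placeholders):
    #   python scripts/convert_pixart_alpha_to_diffusers.py \
    #       --orig_ckpt_path /path/to/PixArt-XL-2-256x256.pth \
    #       --dump_path ./pixart-alpha-256 \
    #       --image_size 256
    import torch

    from diffusers import PixArtAlphaPipeline

    pipe = PixArtAlphaPipeline.from_pretrained("./pixart-alpha-256", torch_dtype=torch.float16).to("cuda")

    # With a 256px checkpoint the transformer's sample_size is 32, so resolution binning now
    # resolves to ASPECT_RATIO_256_BIN (previously it silently fell back to the 512px bins).
    image = pipe(
        prompt="A small cactus with a happy face in the Sahara desert",
        height=256,
        width=256,
        use_resolution_binning=True,
        max_sequence_length=120,  # T5 token budget, now exposed as a call argument (default 120)
    ).images[0]
    image.save("pixart_256.png")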