From 1955579ab720ba5e2b36b6e2ee8144a59b35bf62 Mon Sep 17 00:00:00 2001
From: ishan-modi
Date: Wed, 12 Mar 2025 12:20:42 +0530
Subject: [PATCH] improve code quality

---
 .../convert_sana_controlnet_to_diffusers.py   | 21 +++++++++++--------
 src/diffusers/__init__.py                     |  8 +++----
 src/diffusers/models/__init__.py              |  4 ++--
 .../models/controlnets/controlnet_sana.py     |  2 +-
 src/diffusers/pipelines/__init__.py           |  2 +-
 src/diffusers/pipelines/sana/pipeline_sana.py |  2 +-
 .../sana/pipeline_sana_controlnet.py          |  6 +++---
 7 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/scripts/convert_sana_controlnet_to_diffusers.py b/scripts/convert_sana_controlnet_to_diffusers.py
index 7fa2c4bf1f..cd5fdfd731 100644
--- a/scripts/convert_sana_controlnet_to_diffusers.py
+++ b/scripts/convert_sana_controlnet_to_diffusers.py
@@ -9,12 +9,8 @@ import torch
 from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download, snapshot_download
 from termcolor import colored
-from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from diffusers import (
-    AutoencoderDC,
-    DPMSolverMultistepScheduler,
-    FlowMatchEulerDiscreteScheduler,
     SanaControlNetModel,
 )
 from diffusers.models.modeling_utils import load_model_dict_into_meta
@@ -53,7 +49,7 @@ def main(args):
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}
 
-    # Patch embeddings. 
+    # Patch embeddings.
     converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
     converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
 
@@ -79,7 +75,7 @@ def main(args):
 
     # y norm
     converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
-    
+
     # Positional embedding interpolation scale.
     interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
 
@@ -128,7 +124,9 @@
         q = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.weight")
         q_bias = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.bias")
         k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.weight"), 2, dim=0)
-        k_bias, v_bias = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.bias"), 2, dim=0)
+        k_bias, v_bias = torch.chunk(
+            state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.bias"), 2, dim=0
+        )
 
         converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
         converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
@@ -145,7 +143,9 @@
         )
 
         # ControlNet After Projection
-        converted_state_dict[f"controlnet_blocks.{depth}.weight"] = state_dict.pop(f"controlnet.{depth}.after_proj.weight")
+        converted_state_dict[f"controlnet_blocks.{depth}.weight"] = state_dict.pop(
+            f"controlnet.{depth}.after_proj.weight"
+        )
         converted_state_dict[f"controlnet_blocks.{depth}.bias"] = state_dict.pop(f"controlnet.{depth}.after_proj.bias")
 
     # ControlNet
@@ -205,7 +205,10 @@ if __name__ == "__main__":
         help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
     )
     parser.add_argument(
-        "--model_type", default="SanaMS_1600M_P1_ControlNet_D7", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
+        "--model_type",
+        default="SanaMS_1600M_P1_ControlNet_D7",
+        type=str,
+        choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"],
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
     parser.add_argument("--dtype", default="fp16", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 227fac3030..1d35588a9b 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -132,8 +132,8 @@ else:
         "OmniGenTransformer2DModel",
         "PixArtTransformer2DModel",
         "PriorTransformer",
-        "SanaTransformer2DModel",
         "SanaControlNetModel",
+        "SanaTransformer2DModel",
         "SD3ControlNetModel",
         "SD3MultiControlNetModel",
         "SD3Transformer2DModel",
@@ -365,9 +365,9 @@
         "PixArtSigmaPAGPipeline",
         "PixArtSigmaPipeline",
         "ReduxImageEncoder",
+        "SanaControlNetPipeline",
         "SanaPAGPipeline",
         "SanaPipeline",
-        "SanaControlNetPipeline",
         "SemanticStableDiffusionPipeline",
         "ShapEImg2ImgPipeline",
         "ShapEPipeline",
@@ -666,8 +666,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         OmniGenTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
-        SanaTransformer2DModel,
         SanaControlNetModel,
+        SanaTransformer2DModel,
         SD3ControlNetModel,
         SD3MultiControlNetModel,
         SD3Transformer2DModel,
@@ -878,9 +878,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
         ReduxImageEncoder,
+        SanaControlNetPipeline,
         SanaPAGPipeline,
         SanaPipeline,
-        SanaControlNetPipeline,
         SemanticStableDiffusionPipeline,
         ShapEImg2ImgPipeline,
         ShapEPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 15e6d26213..2758884e25 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -48,13 +48,13 @@ if is_torch_available():
             "HunyuanDiT2DControlNetModel",
             "HunyuanDiT2DMultiControlNetModel",
         ]
+        _import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
         _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
         _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
         _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
         _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
         _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
         _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
-        _import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
         _import_structure["embeddings"] = ["ImageProjection"]
         _import_structure["modeling_utils"] = ["ModelMixin"]
         _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -132,11 +132,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         HunyuanDiT2DMultiControlNetModel,
         MultiControlNetModel,
         MultiControlNetUnionModel,
+        SanaControlNetModel,
         SD3ControlNetModel,
         SD3MultiControlNetModel,
         SparseControlNetModel,
         UNetControlNetXSModel,
-        SanaControlNetModel,
     )
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
diff --git a/src/diffusers/models/controlnets/controlnet_sana.py b/src/diffusers/models/controlnets/controlnet_sana.py
index a39bb44574..f3055ad7e5 100644
--- a/src/diffusers/models/controlnets/controlnet_sana.py
+++ b/src/diffusers/models/controlnets/controlnet_sana.py
@@ -267,7 +267,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 post_patch_width,
             )
             block_res_samples = block_res_samples + (hidden_states,)
-        
+
         # 3. ControlNet blocks
         controlnet_block_res_samples = ()
         for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 2d427657d4..baf27a68f0 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -651,7 +651,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .paint_by_example import PaintByExamplePipeline
     from .pia import PIAPipeline
     from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-    from .sana import SanaPipeline, SanaControlNetPipeline
+    from .sana import SanaControlNetPipeline, SanaPipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
     from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
     from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index 8f80a9ddba..7324949c1b 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -927,7 +927,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
                 # learned sigma
                 if self.transformer.config.out_channels // 2 == latent_channels:
-                    noise_pred = noise_pred.chunk(2, dim=1)[0] 
+                    noise_pred = noise_pred.chunk(2, dim=1)[0]
 
                 # compute previous image: x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
index 4326ced98a..efd4700afd 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
@@ -25,7 +25,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, PixArtImageProcessor
 from ...loaders import SanaLoraLoaderMixin
-from ...models import AutoencoderDC, SanaTransformer2DModel, SanaControlNetModel
+from ...models import AutoencoderDC, SanaControlNetModel, SanaTransformer2DModel
 from ...schedulers import DPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -210,12 +210,12 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         super().__init__()
 
         self.register_modules(
-            tokenizer=tokenizer, 
+            tokenizer=tokenizer,
             text_encoder=text_encoder,
             vae=vae,
             transformer=transformer,
             controlnet=controlnet,
-            scheduler=scheduler
+            scheduler=scheduler,
        )
 
         self.vae_scale_factor = (
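
The patch above is housekeeping only: import sorting, removal of unused imports, line wrapping, trailing-whitespace cleanup, and a trailing comma; no behavior changes. One pre-existing wart it leaves untouched: the `--model_type` default `"SanaMS_1600M_P1_ControlNet_D7"` does not appear in its own `choices` list, so passing the default explicitly on the command line would be rejected. For orientation, here is a minimal, hedged sketch of how the converted checkpoint might be exercised end to end. The local path `./sana-controlnet-diffusers` is a placeholder for whatever `--dump_path` was used; the base-model repo id and the `control_image=` argument name are assumptions based on the conventions of other diffusers ControlNet pipelines, not something this diff confirms.

# Usage sketch -- NOT part of the patch. Names marked "placeholder" or
# "assumption" are illustrative only.
import torch

from diffusers import SanaControlNetModel, SanaControlNetPipeline
from diffusers.utils import load_image

# Placeholder: wherever --dump_path pointed when the conversion script ran.
controlnet = SanaControlNetModel.from_pretrained(
    "./sana-controlnet-diffusers", torch_dtype=torch.float16
)

# Assumption: a Sana base checkpoint in diffusers format; any compatible
# repo id should work here.
pipe = SanaControlNetPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.to("cuda")

control_image = load_image("control.png")  # placeholder conditioning image

# Assumption: the conditioning argument is named `control_image`, following
# the convention of other diffusers ControlNet pipelines.
image = pipe(
    prompt="a cyberpunk cat with a neon sign that says 'Sana'",
    control_image=control_image,
).images[0]
image.save("output.png")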