mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

improve code quality

This commit is contained in:
ishan-modi
2025-03-12 12:20:42 +05:30
parent 321193beb4
commit 1955579ab7
7 changed files with 24 additions and 21 deletions

View File

@@ -9,12 +9,8 @@ import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
AutoencoderDC,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaControlNetModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
@@ -53,7 +49,7 @@ def main(args):
state_dict = all_state_dict.pop("state_dict")
converted_state_dict = {}
-# Patch embeddings.
+# Patch embeddings.
converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
@@ -79,7 +75,7 @@ def main(args):
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
@@ -128,7 +124,9 @@ def main(args):
q = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.weight")
q_bias = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.bias")
k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.weight"), 2, dim=0)
-k_bias, v_bias = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.bias"), 2, dim=0)
+k_bias, v_bias = torch.chunk(
+    state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.bias"), 2, dim=0
+)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
@@ -145,7 +143,9 @@ def main(args):
)
# ControlNet After Projection
converted_state_dict[f"controlnet_blocks.{depth}.weight"] = state_dict.pop(f"controlnet.{depth}.after_proj.weight")
converted_state_dict[f"controlnet_blocks.{depth}.weight"] = state_dict.pop(
f"controlnet.{depth}.after_proj.weight"
)
converted_state_dict[f"controlnet_blocks.{depth}.bias"] = state_dict.pop(f"controlnet.{depth}.after_proj.bias")
# ControlNet
@@ -205,7 +205,10 @@ if __name__ == "__main__":
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_ControlNet_D7", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
"--model_type",
default="SanaMS_1600M_P1_ControlNet_D7",
type=str,
choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"],
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--dtype", default="fp16", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

View File

@@ -132,8 +132,8 @@ else:
"OmniGenTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SanaTransformer2DModel",
"SanaControlNetModel",
"SanaTransformer2DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
@@ -365,9 +365,9 @@ else:
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaPAGPipeline",
"SanaPipeline",
"SanaControlNetPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -666,8 +666,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
-SanaTransformer2DModel,
SanaControlNetModel,
+SanaTransformer2DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
@@ -878,9 +878,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
ReduxImageEncoder,
+SanaControlNetPipeline,
SanaPAGPipeline,
SanaPipeline,
-SanaControlNetPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,

View File

@@ -48,13 +48,13 @@ if is_torch_available():
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DMultiControlNetModel",
]
_import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
_import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -132,11 +132,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DMultiControlNetModel,
MultiControlNetModel,
MultiControlNetUnionModel,
+SanaControlNetModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SparseControlNetModel,
UNetControlNetXSModel,
-SanaControlNetModel,
)
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
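The `_import_structure` mapping edited above feeds diffusers' lazy-module machinery, so reordering its entries only affects readability, not behavior. A rough sketch of the same lazy-import idea using PEP 562 module-level `__getattr__` (the package layout and submodule names here are hypothetical, not diffusers' actual implementation):

# hypothetical my_package/__init__.py
import importlib

_import_structure = {
    "controlnets.controlnet_sana": ["SanaControlNetModel"],
    "transformers.sana_transformer": ["SanaTransformer2DModel"],
}

# invert the mapping so attribute lookups know which submodule to import
_attr_to_module = {attr: module for module, attrs in _import_structure.items() for attr in attrs}


def __getattr__(name):
    # import the submodule only when one of its classes is first requested
    if name in _attr_to_module:
        module = importlib.import_module(f".{_attr_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")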

View File

@@ -267,7 +267,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
post_patch_width,
)
block_res_samples = block_res_samples + (hidden_states,)
# 3. ControlNet blocks
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
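The hunk above collects each transformer block's hidden states and then projects every collected residual through a matching ControlNet block. A small standalone sketch of that pattern (sizes are placeholders, and the real blocks are more than plain linear layers):

import torch
from torch import nn

hidden_dim, num_blocks = 8, 3  # placeholder sizes
controlnet_blocks = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_blocks)])

hidden_states = torch.randn(1, 16, hidden_dim)
block_res_samples = ()
for _ in range(num_blocks):
    # a real transformer block would update hidden_states here
    block_res_samples = block_res_samples + (hidden_states,)

# project each collected residual through its own ControlNet block, mirroring the loop above
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, controlnet_blocks):
    controlnet_block_res_samples = controlnet_block_res_samples + (controlnet_block(block_res_sample),)

print(len(controlnet_block_res_samples))  # 3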

View File

@@ -651,7 +651,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-from .sana import SanaPipeline, SanaControlNetPipeline
+from .sana import SanaControlNetPipeline, SanaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel

View File

@@ -927,7 +927,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
-noise_pred = noise_pred.chunk(2, dim=1)[0]
+noise_pred = noise_pred.chunk(2, dim=1)[0]
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
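The learned-sigma branch above keeps only the noise half of the transformer output: when the model's out_channels is twice the latent channel count, the prediction stacks noise and variance along dim=1. A tiny illustration with placeholder shapes:

import torch

latent_channels = 4  # placeholder latent channel count
out_channels = 2 * latent_channels  # learned-sigma models predict noise and variance together

noise_pred = torch.randn(1, out_channels, 32, 32)  # stand-in for the transformer output
if out_channels // 2 == latent_channels:
    noise_pred = noise_pred.chunk(2, dim=1)[0]  # keep the noise half, drop the variance half

print(noise_pred.shape)  # torch.Size([1, 4, 32, 32])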

View File

@@ -25,7 +25,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, PixArtImageProcessor
from ...loaders import SanaLoraLoaderMixin
-from ...models import AutoencoderDC, SanaTransformer2DModel, SanaControlNetModel
+from ...models import AutoencoderDC, SanaControlNetModel, SanaTransformer2DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -210,12 +210,12 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
super().__init__()
self.register_modules(
-tokenizer=tokenizer,
+tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
controlnet=controlnet,
-scheduler=scheduler
+scheduler=scheduler,
)
self.vae_scale_factor = (
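The `register_modules` call above wires up the components the pipeline needs. A hedged usage sketch, assuming a converted checkpoint has been saved in diffusers format; the checkpoint path is a placeholder and the `control_image` keyword is an assumption, not taken from this diff:

import torch
from diffusers import SanaControlNetPipeline
from diffusers.utils import load_image

# "path/to/converted-sana-controlnet" is a placeholder for a checkpoint produced by the
# conversion script above (or a hub repo id); fp16 mirrors the script's default --dtype
pipe = SanaControlNetPipeline.from_pretrained("path/to/converted-sana-controlnet", torch_dtype=torch.float16)
pipe.to("cuda")

control_image = load_image("path/to/condition.png")  # placeholder conditioning image
# the conditioning-image keyword is assumed to be `control_image`; check the pipeline's docstring
image = pipe(prompt="a photo of a cat", control_image=control_image).images[0]
image.save("sana_controlnet_out.png")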