Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Fix typos (#9739)
* update * update * update * update * update * update
@@ -313,6 +313,26 @@ image = pipe("a picture of a cat holding a sign that says hello world").images[0
 image.save('sd3-single-file-t5-fp8.png')
 ```
 
+### Loading the single file checkpoint for the Stable Diffusion 3.5 Transformer Model
+
+```python
+import torch
+from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline
+
+transformer = SD3Transformer2DModel.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo/blob/main/sd3.5_large.safetensors",
+    torch_dtype=torch.bfloat16,
+)
+pipe = StableDiffusion3Pipeline.from_pretrained(
+    "stabilityai/stable-diffusion-3.5-large",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+image = pipe("a cat holding a sign that says hello world").images[0]
+image.save("sd35.png")
+```
+
 ## StableDiffusion3Pipeline
 
 [[autodoc]] StableDiffusion3Pipeline
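For reference, `from_single_file` also accepts a local path, so the same transformer can be loaded from a checkpoint downloaded ahead of time. A minimal sketch, assuming the file already exists on disk (the path below is a placeholder, not part of the diff):

```python
import torch
from diffusers import SD3Transformer2DModel

# Load the SD 3.5 transformer from a locally downloaded single-file checkpoint;
# the path is a placeholder for wherever the .safetensors file was saved.
transformer = SD3Transformer2DModel.from_single_file(
    "./checkpoints/sd3.5_large.safetensors",
    torch_dtype=torch.bfloat16,
)
```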
@@ -75,6 +75,7 @@ CHECKPOINT_KEY_NAMES = {
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -113,6 +114,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "sd3": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
     },
+    "sd35_large": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+    },
     "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
@@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"
 
-    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
         model_type = "sd3"
 
+    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
+        model_type = "sd35_large"
+
     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
         if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
             model_type = "animatediff_scribble"
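For context, the new `sd35_large` branch (and the tightened `sd3` check) identify the model purely by probing state-dict keys and shapes. A small illustrative sketch of the same idea outside of diffusers, using the key names from the hunk above (the helper name and checkpoint path are made up for this example):

```python
SD3_KEY = "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias"
SD35_LARGE_KEY = "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight"


def guess_sd3_variant(checkpoint):
    """Guess the SD3-family variant from a raw (non-diffusers) state dict."""
    # SD 3 medium: the adaLN bias of joint block 0 is 6 * 1536 = 9216 wide.
    if SD3_KEY in checkpoint and checkpoint[SD3_KEY].shape[-1] == 9216:
        return "sd3"
    # SD 3.5 large checkpoints have 38 joint blocks, so block index 37 exists.
    if SD35_LARGE_KEY in checkpoint:
        return "sd35_large"
    return "unknown"


# Example usage (placeholder path):
#   from safetensors.torch import load_file
#   print(guess_sd3_variant(load_file("./checkpoints/sd3.5_large.safetensors")))
```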
@@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
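To make the two new helpers concrete, here is a toy state dict run through them. The helpers are restated in compact form so the snippet runs on its own; the tensor shapes are illustrative only:

```python
import torch


def get_attn2_layers(state_dict):  # compact restatement of the helper above
    attn2_layers = []
    for key in state_dict.keys():
        if "attn2." in key:
            attn2_layers.append(int(key.split(".")[1]))
    return tuple(sorted(set(attn2_layers)))


def get_caption_projection_dim(state_dict):  # compact restatement of the helper above
    return state_dict["context_embedder.weight"].shape[0]


# Toy state dict: dual attention (attn2) appears only in joint blocks 0 and 12,
# and the context embedder's output width is the caption projection dim.
toy_state_dict = {
    "context_embedder.weight": torch.zeros(2432, 4096),
    "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(3 * 2432, 2432),
    "joint_blocks.12.x_block.attn2.qkv.weight": torch.zeros(3 * 2432, 2432),
}

print(get_attn2_layers(toy_state_dict))            # (0, 12)
print(get_caption_projection_dim(toy_state_dict))  # 2432
```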
@@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+    dual_attention_layers = get_attn2_layers(checkpoint)
+
+    caption_projection_dim = get_caption_projection_dim(checkpoint)
+    has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
 
     # Positional and patch embeddings.
     converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
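The block count is likewise inferred from the key names: every `joint_blocks.<i>.` key contributes its block index, and the largest index plus one is the layer count. A tiny sketch of the same idea on a toy key list (using `sorted()` here instead of relying on set ordering):

```python
# Toy key list covering joint blocks 0 and 1 plus an unrelated key.
keys = [
    "joint_blocks.0.x_block.attn.qkv.weight",
    "joint_blocks.1.context_block.attn.qkv.weight",
    "final_layer.linear.weight",
]

# The second dotted field of each joint_blocks key is the block index.
num_layers = sorted(set(int(k.split(".", 2)[1]) for k in keys if "joint_blocks" in k))[-1] + 1
print(num_layers)  # 2
```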
@@ -1735,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
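The `ln_q` / `ln_k` tensors being remapped here are the learned scales used for query/key normalization before the attention dot product. A conceptual sketch of what such a norm does (an RMS-style norm; the exact flavor used by the model is not shown in this diff, so treat this as illustrative):

```python
import torch


def qk_rms_norm(x, weight, eps=1e-6):
    # x: (batch, heads, tokens, head_dim); weight: (head_dim,) learned scale,
    # i.e. the role played by the ln_q.weight / ln_k.weight tensors above.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight


q = torch.randn(2, 8, 16, 64)
q_normed = qk_rms_norm(q, torch.ones(64))
print(q_normed.shape)  # torch.Size([2, 8, 16, 64])
```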
@@ -1750,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
                 f"joint_blocks.{i}.context_block.attn.proj.bias"
             )
 
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
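The attn2 branch above splits the fused QKV weights and biases the same way the primary attention path does: the projection is stored as q, k, and v stacked along dim 0, so `torch.chunk(..., 3, dim=0)` recovers the three parts. A minimal sketch of that split on a toy tensor:

```python
import torch

hidden = 8  # illustrative width
qkv_weight = torch.randn(3 * hidden, hidden)  # rows: [q | k | v]
qkv_bias = torch.randn(3 * hidden)

q_w, k_w, v_w = torch.chunk(qkv_weight, 3, dim=0)
q_b, k_b, v_b = torch.chunk(qkv_bias, 3, dim=0)

print(q_w.shape, q_b.shape)  # torch.Size([8, 8]) torch.Size([8])
```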