mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Diffusers -> Original SD conversion] fix things (#6933)
* fix: bias loading bug * fixes for SDXL * apply changes to the conversion script to match single_file_utils.py * do transpose to match the single file loading logic.
This commit is contained in:
@@ -167,7 +167,10 @@ vae_conversion_map_attn = [
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
if not w.ndim == 1:
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
else:
|
||||
return w
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
@@ -321,11 +324,18 @@ if __name__ == "__main__":
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
# Convert text encoder 1
|
||||
text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
# Convert text encoder 2
|
||||
text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
|
||||
text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
|
||||
# We call the `.T.contiguous()` to match what's done in
|
||||
# https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
|
||||
text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
|
||||
"conditioner.embedders.1.model.text_projection.weight"
|
||||
).T.contiguous()
|
||||
|
||||
# Put together new checkpoint
|
||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
|
||||
|
||||
@@ -170,7 +170,10 @@ vae_extra_conversion_map = [
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
if not w.ndim == 1:
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
else:
|
||||
return w
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
|
||||
Reference in New Issue
Block a user