mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Convert alphas for embedders for sd-scripts to ai toolkit conversion (#12332)
* Convert alphas for embedders for sd-scripts to ai toolkit conversion * Add kohya embedders conversion test * Apply style fixes --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -558,70 +558,62 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
ait_sd[target_key] = value
|
||||
|
||||
if any("guidance_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_in_layer",
|
||||
"time_text_embed.guidance_embedder.linear_1",
|
||||
)
|
||||
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_out_layer",
|
||||
"time_text_embed.guidance_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("img_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_img_in",
|
||||
"x_embedder",
|
||||
)
|
||||
|
||||
if any("txt_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_txt_in",
|
||||
"context_embedder",
|
||||
)
|
||||
|
||||
if any("time_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_in_layer",
|
||||
"time_text_embed.timestep_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_out_layer",
|
||||
"time_text_embed.timestep_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("vector_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_in_layer",
|
||||
"time_text_embed.text_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_out_layer",
|
||||
"time_text_embed.text_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("final_layer" in k for k in sds_sd):
|
||||
|
||||
@@ -907,6 +907,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_kohya_embedders_conversion(self):
|
||||
"""Test that embedders load without throwing errors"""
|
||||
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
|
||||
self.pipeline.unload_lora_weights()
|
||||
|
||||
assert True
|
||||
|
||||
def test_flux_xlabs(self):
|
||||
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
|
||||
Reference in New Issue
Block a user