Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[LoRA] Support original format loras for HunyuanVideo (#10376)
* update
* fix make copies
* update
* add relevant markers to the integration test suite.
* add copied.
* fix-copies
* temporarily add print.
* directly place on CUDA as CPU isn't that big on the CI.
* fixes to fuse_lora, aryan was right.
* fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
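In practice, this change lets an original-format (non-diffusers) HunyuanVideo LoRA load straight into the pipeline. A minimal sketch, assuming a CUDA device; the model and LoRA repos are the ones exercised by the integration test added below, and the generation settings are illustrative:

import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.float16
).to("cuda")

# Original-format LoRA keys such as "img_attn_qkv" are now detected in
# lora_state_dict() and converted to the diffusers layout on the fly.
pipe.load_lora_weights(
    "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1",
    weight_name="csetiarcane-nfjinx-v1-6000.safetensors",
)

video = pipe(
    prompt="CSETIARCANE. A cat walks on the grass, realistic",
    num_frames=9,
    num_inference_steps=10,
).frames[0]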
@@ -973,3 +973,178 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict


def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}

    def remap_norm_scale_shift_(key, state_dict):
        # The original checkpoint stores the modulation as (shift, scale); diffusers
        # expects (scale, shift), so swap the two halves before renaming the key.
        weight = state_dict.pop(key)
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight

    def remap_txt_in_(key, state_dict):
        def rename_key(key):
            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
            new_key = new_key.replace("txt_in", "context_embedder")
            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
            new_key = new_key.replace("mlp", "ff")
            return new_key

        if "self_attn_qkv" in key:
            weight = state_dict.pop(key)
            to_q, to_k, to_v = weight.chunk(3, dim=0)
            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
        else:
            state_dict[rename_key(key)] = state_dict.pop(key)

    def remap_img_attn_qkv_(key, state_dict):
        weight = state_dict.pop(key)
        if "lora_A" in key:
            # lora_A projects the shared input down to the rank, so the same matrix
            # serves all three projections; only lora_B has per-projection rows.
            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
        else:
            to_q, to_k, to_v = weight.chunk(3, dim=0)
            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v

    def remap_txt_attn_qkv_(key, state_dict):
        weight = state_dict.pop(key)
        if "lora_A" in key:
            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
        else:
            to_q, to_k, to_v = weight.chunk(3, dim=0)
            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v

    def remap_single_transformer_blocks_(key, state_dict):
        hidden_size = 3072

        if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
            # linear1 fuses the Q/K/V projections and the MLP input projection. The
            # lora_A (down) matrix acts on the shared input, so it is reused as-is for
            # every projection; the lora_B (up) matrix is split along its output rows.
            linear1_weight = state_dict.pop(key)
            if "lora_A" in key:
                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
                    ".linear1.lora_A.weight"
                )
                state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
                state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
                state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
                state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
            else:
                split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
                q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
                    ".linear1.lora_B.weight"
                )
                state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
                state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
                state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
                state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp

        elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
            linear1_bias = state_dict.pop(key)
            if "lora_A" in key:
                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
                    ".linear1.lora_A.bias"
                )
                state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
                state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
                state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
                state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
            else:
                split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
                q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
                    ".linear1.lora_B.bias"
                )
                state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
                state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
                state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
                state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias

        else:
            new_key = key.replace("single_blocks", "single_transformer_blocks")
            new_key = new_key.replace("linear2", "proj_out")
            new_key = new_key.replace("q_norm", "attn.norm_q")
            new_key = new_key.replace("k_norm", "attn.norm_k")
            state_dict[new_key] = state_dict.pop(key)

    TRANSFORMER_KEYS_RENAME_DICT = {
        "img_in": "x_embedder",
        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
        "double_blocks": "transformer_blocks",
        "img_attn_q_norm": "attn.norm_q",
        "img_attn_k_norm": "attn.norm_k",
        "img_attn_proj": "attn.to_out.0",
        "txt_attn_q_norm": "attn.norm_added_q",
        "txt_attn_k_norm": "attn.norm_added_k",
        "txt_attn_proj": "attn.to_add_out",
        "img_mod.linear": "norm1.linear",
        "img_norm1": "norm1.norm",
        "img_norm2": "norm2",
        "img_mlp": "ff",
        "txt_mod.linear": "norm1_context.linear",
        "txt_norm1": "norm1.norm",
        "txt_norm2": "norm2_context",
        "txt_mlp": "ff_context",
        "self_attn_proj": "attn.to_out.0",
        "modulation.linear": "norm.linear",
        "pre_norm": "norm.norm",
        "final_layer.norm_final": "norm_out.norm",
        "final_layer.linear": "proj_out",
        "fc1": "net.0.proj",
        "fc2": "net.2",
        "input_embedder": "proj_in",
    }

    TRANSFORMER_SPECIAL_KEYS_REMAP = {
        "txt_in": remap_txt_in_,
        "img_attn_qkv": remap_img_attn_qkv_,
        "txt_attn_qkv": remap_txt_attn_qkv_,
        "single_blocks": remap_single_transformer_blocks_,
        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
    }

    # Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
    # and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
    # sure that both follow the same initial format by stripping off the "transformer." prefix.
    for key in list(converted_state_dict.keys()):
        if key.startswith("transformer."):
            converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
        if key.startswith("diffusion_model."):
            converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)

    # Rename and remap the state dict keys
    for key in list(converted_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        converted_state_dict[new_key] = converted_state_dict.pop(key)

    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, converted_state_dict)

    # Add back the "transformer." prefix
    for key in list(converted_state_dict.keys()):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict

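To make the mapping concrete, here is a minimal sketch of the converter applied to a toy rank-4 entry (the key names follow the original checkpoint layout; the shapes are illustrative):

import torch

original = {
    "diffusion_model.double_blocks.0.img_attn_qkv.lora_A.weight": torch.randn(4, 3072),
    "diffusion_model.double_blocks.0.img_attn_qkv.lora_B.weight": torch.randn(3 * 3072, 4),
}
converted = _convert_hunyuan_video_lora_to_diffusers(original)
# The fused qkv entry becomes separate to_q/to_k/to_v entries under a "transformer." prefix:
#   transformer.transformer_blocks.0.attn.to_q.lora_A.weight  (rank x in, shared across q/k/v)
#   transformer.transformer_blocks.0.attn.to_q.lora_B.weight  (3072 x rank, one chunk of lora_B)
#   ...and likewise for to_k / to_v.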
@@ -36,6 +36,7 @@ from ..utils import (
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
from .lora_conversion_utils import (
    _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_hunyuan_video_lora_to_diffusers,
    _convert_kohya_flux_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
    _convert_xlabs_flux_lora_to_diffusers,
@@ -4007,7 +4008,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):

    @classmethod
    @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4018,7 +4018,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):

        <Tip warning={true}>

-        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+        We support loading original format HunyuanVideo LoRA checkpoints.

        This function is experimental and might change in the future.
@@ -4101,6 +4101,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

+        is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+        if is_original_hunyuan_video:
+            state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)

        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
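The conversion is transparent to callers; a small sketch of inspecting the converted keys through the classmethod above (same repo and weight name as the integration test below):

state_dict = HunyuanVideoPipeline.lora_state_dict(
    "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1",
    weight_name="csetiarcane-nfjinx-v1-6000.safetensors",
)
# Every key now follows the diffusers naming scheme, e.g.
# "transformer.transformer_blocks.0.attn.to_q.lora_A.weight".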
@@ -4239,10 +4243,9 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
            safe_serialization=safe_serialization,
        )

-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
    def fuse_lora(
        self,
-        components: List[str] = ["transformer", "text_encoder"],
+        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
@@ -4283,8 +4286,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )

-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
-    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
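With the updated defaults, fusing targets only the transformer (HunyuanVideo carries no text encoder LoRA). A minimal sketch of the round trip, continuing the pipe from the first example:

pipe.fuse_lora(lora_scale=1.0)  # merge the LoRA into the transformer weights
# ...run inference without the LoRA layer overhead...
pipe.unfuse_lora()              # restore the original, unfused transformer weights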
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import gc
import sys
import unittest

+import numpy as np
+import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -26,7 +29,11 @@ from diffusers import (
)
from diffusers.utils.testing_utils import (
    floats_tensor,
+    nightly,
+    numpy_cosine_similarity_distance,
+    require_big_gpu_with_torch_cuda,
    require_peft_backend,
+    require_torch_gpu,
    skip_mps,
)
@@ -182,3 +189,69 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass


@nightly
@require_torch_gpu
@require_peft_backend
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
    """internal note: The integration slices were obtained on DGX.

    torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
    assertions to pass.
    """

    num_inference_steps = 10
    seed = 0

    def setUp(self):
        super().setUp()

        gc.collect()
        torch.cuda.empty_cache()

        model_id = "hunyuanvideo-community/HunyuanVideo"
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id, subfolder="transformer", torch_dtype=torch.bfloat16
        )
        self.pipeline = HunyuanVideoPipeline.from_pretrained(
            model_id, transformer=transformer, torch_dtype=torch.float16
        ).to("cuda")

    def tearDown(self):
        super().tearDown()

        gc.collect()
        torch.cuda.empty_cache()

    def test_original_format_cseti(self):
        self.pipeline.load_lora_weights(
            "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
        )
        self.pipeline.fuse_lora()
        self.pipeline.unload_lora_weights()
        self.pipeline.vae.enable_tiling()

        prompt = "CSETIARCANE. A cat walks on the grass, realistic"

        out = self.pipeline(
            prompt=prompt,
            height=320,
            width=512,
            num_frames=9,
            num_inference_steps=self.num_inference_steps,
            output_type="np",
            generator=torch.manual_seed(self.seed),
        ).frames[0]
        out = out.flatten()
        out_slice = np.concatenate((out[:8], out[-8:]))

        # fmt: off
        expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
        # fmt: on

        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

        assert max_diff < 1e-3