Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Fix Unfuse Lora (#4833)

* Fix Unfuse Lora
* add tests
* Fix more
* Fix more
* Fix all
* make style
* make style

Committed by GitHub
parent fbca2e0a7a
commit 9f1936d2fc
@@ -85,12 +85,21 @@ class PatchedLoraProjection(nn.Module):
 
         self.lora_scale = lora_scale
 
+    # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
+    # when saving the whole text encoder model and when LoRA is unloaded or fused
+    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+        if self.lora_linear_layer is None:
+            return self.regular_linear_layer.state_dict(
+                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
+            )
+
+        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
+
     def _fuse_lora(self):
         if self.lora_linear_layer is None:
             return
 
         dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
-        logger.info(f"Fusing LoRA weights for {self.__class__}")
 
         w_orig = self.regular_linear_layer.weight.data.float()
         w_up = self.lora_linear_layer.up.weight.data.float()
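The new `state_dict` override makes a patched projection serialize exactly like the `nn.Linear` it wraps whenever no LoRA layer is attached, so checkpoints saved after unloading keep their original keys. A minimal sketch of that delegation pattern, using a toy `WrappedLinear` stand-in rather than the real `PatchedLoraProjection`:

import torch
import torch.nn as nn


# Illustrative stand-in for the patched projection: when no LoRA layer is
# attached, state_dict() falls through to the wrapped linear layer so the
# saved keys look exactly like a plain nn.Linear checkpoint.
class WrappedLinear(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.regular_linear_layer = linear
        self.lora_linear_layer = None  # would hold the LoRA up/down projection when loaded

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if self.lora_linear_layer is None:
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)


parent = nn.Sequential(WrappedLinear(nn.Linear(4, 4)))
print(list(parent.state_dict().keys()))  # ['0.weight', '0.bias'] -- no 'regular_linear_layer.' segment

Because `nn.Module.state_dict` recurses through children by calling their own `state_dict`, overriding it on the wrapper is enough to keep the `regular_linear_layer.` segment out of a whole text encoder's saved keys.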
@@ -112,14 +121,14 @@ class PatchedLoraProjection(nn.Module):
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
             return
-        logger.info(f"Unfusing LoRA weights for {self.__class__}")
 
         fused_weight = self.regular_linear_layer.weight.data
         dtype, device = fused_weight.dtype, fused_weight.device
 
-        self.w_up = self.w_up.to(device=device, dtype=dtype)
-        self.w_down = self.w_down.to(device, dtype=dtype)
-        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+        w_up = self.w_up.to(device=device).float()
+        w_down = self.w_down.to(device).float()
+
+        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
         self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
 
         self.w_up = None
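The substantive change to `_unfuse_lora` is that the subtraction now runs in float32 and the result is cast back to the module dtype only once, instead of first casting the stored `w_up`/`w_down` factors down to the (possibly fp16) weight dtype, which also overwrote the stored copies in place. A rough, self-contained sketch of the fixed round trip on arbitrary fp16 data, not the diffusers code itself:

import torch

torch.manual_seed(0)

weight = torch.randn(32, 32, dtype=torch.float16)  # stands in for regular_linear_layer.weight.data
w_up = torch.randn(32, 4)                           # LoRA factors kept in float32
w_down = torch.randn(4, 32)

# fuse: do the math in float32, cast once when writing back
fused = (weight.float() + torch.bmm(w_up[None, :], w_down[None, :])[0]).to(weight.dtype)

# unfuse (the fixed path): float32 math again, single cast at the end
restored = (fused.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]).to(weight.dtype)

# the round trip stays within fp16 rounding of the original weight
print((restored - weight).abs().max().item())

With the factors kept in full precision, a fuse/unfuse cycle returns the weight to within fp16 rounding of its original value, which is what the new integration test at the bottom of this diff checks end to end.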
@@ -1405,15 +1414,15 @@ class LoraLoaderMixin:
     def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
         for _, attn_module in text_encoder_attn_modules(text_encoder):
             if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                attn_module.q_proj = attn_module.q_proj.regular_linear_layer
-                attn_module.k_proj = attn_module.k_proj.regular_linear_layer
-                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
-                attn_module.out_proj = attn_module.out_proj.regular_linear_layer
+                attn_module.q_proj.lora_linear_layer = None
+                attn_module.k_proj.lora_linear_layer = None
+                attn_module.v_proj.lora_linear_layer = None
+                attn_module.out_proj.lora_linear_layer = None
 
         for _, mlp_module in text_encoder_mlp_modules(text_encoder):
             if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
-                mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
+                mlp_module.fc1.lora_linear_layer = None
+                mlp_module.fc2.lora_linear_layer = None
 
     @classmethod
     def _modify_text_encoder(
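Unloading now keeps the `PatchedLoraProjection` wrappers in place and only clears their `lora_linear_layer`; together with the `state_dict` override above, the text encoder still saves and loads as if it had never been patched, while the bookkeeping needed for a later `unfuse_lora` survives. A toy illustration of why nulling the LoRA branch is a sufficient unload (a made-up wrapper, not the actual diffusers forward):

import torch
import torch.nn as nn


# Toy version of the patched-projection pattern this hunk relies on: once
# lora_linear_layer is None, the wrapper behaves exactly like the original
# linear layer, so swapping modules out is not needed to "unload" a LoRA.
class ToyPatchedLinear(nn.Module):
    def __init__(self, linear, rank=4, scale=1.0):
        super().__init__()
        self.regular_linear_layer = linear
        self.lora_linear_layer = nn.Sequential(
            nn.Linear(linear.in_features, rank, bias=False),
            nn.Linear(rank, linear.out_features, bias=False),
        )
        self.lora_scale = scale

    def forward(self, x):
        out = self.regular_linear_layer(x)
        if self.lora_linear_layer is not None:
            out = out + self.lora_scale * self.lora_linear_layer(x)
        return out


layer = ToyPatchedLinear(nn.Linear(8, 8))
x = torch.randn(2, 8)
layer.lora_linear_layer = None  # "unload" the LoRA branch
assert torch.equal(layer(x), layer.regular_linear_layer(x))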
@@ -1447,23 +1456,43 @@ class LoraLoaderMixin:
             else:
                 current_rank = rank
 
+            q_linear_layer = (
+                attn_module.q_proj.regular_linear_layer
+                if isinstance(attn_module.q_proj, PatchedLoraProjection)
+                else attn_module.q_proj
+            )
             attn_module.q_proj = PatchedLoraProjection(
-                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
+                q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
 
+            k_linear_layer = (
+                attn_module.k_proj.regular_linear_layer
+                if isinstance(attn_module.k_proj, PatchedLoraProjection)
+                else attn_module.k_proj
+            )
             attn_module.k_proj = PatchedLoraProjection(
-                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
+                k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
 
+            v_linear_layer = (
+                attn_module.v_proj.regular_linear_layer
+                if isinstance(attn_module.v_proj, PatchedLoraProjection)
+                else attn_module.v_proj
+            )
             attn_module.v_proj = PatchedLoraProjection(
-                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
+                v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
 
+            out_linear_layer = (
+                attn_module.out_proj.regular_linear_layer
+                if isinstance(attn_module.out_proj, PatchedLoraProjection)
+                else attn_module.out_proj
+            )
             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
+                out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
 
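Because the projections may already be `PatchedLoraProjection` instances the next time LoRA weights are loaded (they now survive unloading), `_modify_text_encoder` unwraps them before re-wrapping; otherwise each load would nest a wrapper inside a wrapper and silently fold the previous LoRA into the new "regular" layer. A hedged sketch of that guard with hypothetical names:

import torch.nn as nn


# Hypothetical stand-ins for the unwrap-before-rewrap guard.
class Patched(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.regular_linear_layer = linear


def patch(proj):
    # unwrap a previously patched projection before wrapping it again
    base = proj.regular_linear_layer if isinstance(proj, Patched) else proj
    return Patched(base)


linear = nn.Linear(8, 8)
once = patch(linear)
twice = patch(once)  # re-loading: still wraps the original nn.Linear, not the wrapper
assert twice.regular_linear_layer is linear

The same unwrap-before-rewrap guard is applied to the MLP `fc1`/`fc2` projections in the next hunk.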
@@ -1475,13 +1504,23 @@ class LoraLoaderMixin:
                 current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
                 current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
 
+            fc1_linear_layer = (
+                mlp_module.fc1.regular_linear_layer
+                if isinstance(mlp_module.fc1, PatchedLoraProjection)
+                else mlp_module.fc1
+            )
             mlp_module.fc1 = PatchedLoraProjection(
-                mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
+                fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
             )
             lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
 
+            fc2_linear_layer = (
+                mlp_module.fc2.regular_linear_layer
+                if isinstance(mlp_module.fc2, PatchedLoraProjection)
+                else mlp_module.fc2
+            )
             mlp_module.fc2 = PatchedLoraProjection(
-                mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
+                fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
             )
             lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
 
@@ -168,7 +168,6 @@ class LoRACompatibleLinear(nn.Linear):
             return
 
         dtype, device = self.weight.data.dtype, self.weight.data.device
-        logger.info(f"Fusing LoRA weights for {self.__class__}")
 
         w_orig = self.weight.data.float()
         w_up = self.lora_layer.up.weight.data.float()
@@ -190,14 +189,14 @@ class LoRACompatibleLinear(nn.Linear):
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
             return
-        logger.info(f"Unfusing LoRA weights for {self.__class__}")
 
         fused_weight = self.weight.data
         dtype, device = fused_weight.dtype, fused_weight.device
 
-        self.w_up = self.w_up.to(device=device, dtype=dtype)
-        self.w_down = self.w_down.to(device, dtype=dtype)
-        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+        w_up = self.w_up.to(device=device).float()
+        w_down = self.w_down.to(device).float()
+
+        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
         self.weight.data = unfused_weight.to(device=device, dtype=dtype)
 
         self.w_up = None
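`LoRACompatibleLinear._unfuse_lora` gets the same float32 treatment. For context, the `w_up`/`w_down` attributes it checks for are stashed at fuse time precisely so the low-rank update can be subtracted later, after the LoRA layer itself may have been dropped. A compact sketch of that bookkeeping on a bare `nn.Linear` (simplified: no scale or `network_alpha` handling, factors assumed to be float32):

import torch
import torch.nn as nn


def fuse(linear, up, down):
    # fold the low-rank update into the weight and remember the factors
    dtype, device = linear.weight.data.dtype, linear.weight.data.device
    fused = linear.weight.data.float() + torch.bmm(up[None, :].float(), down[None, :].float())[0]
    linear.weight.data = fused.to(device=device, dtype=dtype)
    linear.w_up, linear.w_down = up.cpu(), down.cpu()  # kept around only so unfusing stays possible


def unfuse(linear):
    if not (getattr(linear, "w_up", None) is not None and getattr(linear, "w_down", None) is not None):
        return
    fused = linear.weight.data
    dtype, device = fused.dtype, fused.device
    w_up = linear.w_up.to(device=device).float()
    w_down = linear.w_down.to(device).float()
    restored = fused.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
    linear.weight.data = restored.to(device=device, dtype=dtype)
    linear.w_up = None
    linear.w_down = None


layer = nn.Linear(16, 16)
before = layer.weight.data.clone()
fuse(layer, torch.randn(16, 4), torch.randn(4, 16))
unfuse(layer)
assert torch.allclose(layer.weight.data, before, atol=1e-4)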
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import os
 import tempfile
 import time
@@ -100,6 +101,18 @@ def set_lora_weights(lora_attn_parameters, randn_weight=False):
             torch.zero_(parameter)
 
 
+def state_dicts_almost_equal(sd1, sd2):
+    sd1 = dict(sorted(sd1.items()))
+    sd2 = dict(sorted(sd2.items()))
+
+    models_are_equal = True
+    for ten1, ten2 in zip(sd1.values(), sd2.values()):
+        if (ten1 - ten2).abs().sum() > 1e-3:
+            models_are_equal = False
+
+    return models_are_equal
+
+
 class LoraLoaderMixinTests(unittest.TestCase):
     def get_dummy_components(self):
         torch.manual_seed(0)
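`state_dicts_almost_equal` sorts both dicts by key and compares tensors pairwise with an absolute-sum tolerance of 1e-3; it assumes the two state dicts share the same set of keys. With the helper in scope, a quick hypothetical sanity check of its behavior:

import torch

# identical dicts pass, a perturbed one fails
sd_a = {"layer.weight": torch.zeros(4, 4), "layer.bias": torch.zeros(4)}
sd_b = {k: v.clone() for k, v in sd_a.items()}
sd_c = {k: v + 1.0 for k, v in sd_a.items()}

assert state_dicts_almost_equal(sd_a, sd_b)
assert not state_dicts_almost_equal(sd_a, sd_c)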
@@ -674,6 +687,45 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
 
         sd_pipe.unload_lora_weights()
 
+    def test_text_encoder_lora_state_dict_unchanged(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+
+        text_encoder_1_sd_keys = sorted(sd_pipe.text_encoder.state_dict().keys())
+        text_encoder_2_sd_keys = sorted(sd_pipe.text_encoder_2.state_dict().keys())
+
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=False,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+
+        text_encoder_1_sd_keys_2 = sorted(sd_pipe.text_encoder.state_dict().keys())
+        text_encoder_2_sd_keys_2 = sorted(sd_pipe.text_encoder_2.state_dict().keys())
+
+        sd_pipe.unload_lora_weights()
+
+        text_encoder_1_sd_keys_3 = sorted(sd_pipe.text_encoder.state_dict().keys())
+        text_encoder_2_sd_keys_3 = sorted(sd_pipe.text_encoder_2.state_dict().keys())
+
+        # default & unloaded LoRA weights should have identical state_dicts
+        assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3
+        # default & loaded LoRA weights should NOT have identical state_dicts
+        assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2
+
+        # default & unloaded LoRA weights should have identical state_dicts
+        assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3
+        # default & loaded LoRA weights should NOT have identical state_dicts
+        assert text_encoder_2_sd_keys != text_encoder_2_sd_keys_2
+
     def test_load_lora_locally_safetensors(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
@@ -1187,3 +1239,22 @@ class LoraIntegrationTests(unittest.TestCase):
         expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094])
 
         self.assertTrue(np.allclose(images, expected, atol=1e-3))
+
+    def test_sdxl_1_0_fuse_unfuse_all(self):
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
+        text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
+        unet_sd = copy.deepcopy(pipe.unet.state_dict())
+
+        pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")
+        pipe.fuse_lora()
+        pipe.unload_lora_weights()
+        pipe.unfuse_lora()
+
+        new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
+        new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
+        new_unet_sd = copy.deepcopy(pipe.unet.state_dict())
+
+        assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)
+        assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)
+        assert state_dicts_almost_equal(unet_sd, new_unet_sd)
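The integration test mirrors the user-facing workflow this PR protects: fuse for faster inference, then unload and unfuse to get the original weights back. Roughly, using the same model and LoRA repo as the test (running this downloads the full SDXL checkpoint; the fp16/CUDA settings here are illustrative additions, not part of the test):

import copy
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
original_unet_sd = copy.deepcopy(pipe.unet.state_dict())

pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")
pipe.fuse_lora()              # fold the LoRA into the base weights for faster inference
# ... run inference ...
pipe.unload_lora_weights()
pipe.unfuse_lora()            # should approximately restore original_unet_sd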