
Fix Unfuse Lora (#4833)

* Fix Unfuse Lora

* add tests

* Fix more

* Fix more

* Fix all

* make style

* make style
Patrick von Platen
2023-08-30 06:02:25 +02:00
committed by GitHub
parent fbca2e0a7a
commit 9f1936d2fc
3 changed files with 131 additions and 22 deletions


@@ -85,12 +85,21 @@ class PatchedLoraProjection(nn.Module):
        self.lora_scale = lora_scale

    # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
    # when saving the whole text encoder model and when LoRA is unloaded or fused
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if self.lora_linear_layer is None:
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )

        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

    def _fuse_lora(self):
        if self.lora_linear_layer is None:
            return

        dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
        logger.info(f"Fusing LoRA weights for {self.__class__}")

        w_orig = self.regular_linear_layer.weight.data.float()
        w_up = self.lora_linear_layer.up.weight.data.float()
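For context, a minimal sketch of the state_dict delegation added above (an illustrative class, not the diffusers PatchedLoraProjection): when no LoRA layer is attached, the wrapper reports exactly the keys of the wrapped nn.Linear, which is what later lets unloading simply set lora_linear_layer to None without changing the checkpoint layout.

import torch
from torch import nn


class WrappedLinear(nn.Module):
    # Hypothetical stand-in for PatchedLoraProjection, reduced to the state_dict behaviour.
    def __init__(self, linear, lora=None):
        super().__init__()
        self.regular_linear_layer = linear
        self.lora_linear_layer = lora

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if self.lora_linear_layer is None:
            # Delegate, so the keys look like those of a plain nn.Linear.
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)


wrapped = WrappedLinear(nn.Linear(4, 4))
print(sorted(wrapped.state_dict().keys()))  # ['bias', 'weight'] -- same keys as the unwrapped Linear

wrapped.lora_linear_layer = nn.Linear(4, 4, bias=False)
print(sorted(wrapped.state_dict().keys()))  # now includes 'lora_linear_layer.weight' next to the 'regular_linear_layer.*' keys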
@@ -112,14 +121,14 @@ class PatchedLoraProjection(nn.Module):
    def _unfuse_lora(self):
        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
            return

        logger.info(f"Unfusing LoRA weights for {self.__class__}")

        fused_weight = self.regular_linear_layer.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device

        self.w_up = self.w_up.to(device=device, dtype=dtype)
        self.w_down = self.w_down.to(device, dtype=dtype)
        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
        w_up = self.w_up.to(device=device).float()
        w_down = self.w_down.to(device).float()
        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]

        self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
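The numerical point of the change above (mirrored in LoRACompatibleLinear further down), as a standalone sketch with random tensors: doing the fuse and unfuse arithmetic in float32 and only casting back to the storage dtype at the end keeps a fuse -> unfuse round trip close to the original half-precision weights. The shapes and seed here are illustrative, not taken from the library.

import torch

torch.manual_seed(0)
w = torch.randn(8, 8, dtype=torch.float16)       # base weight the LoRA gets fused into
w_up = torch.randn(8, 4, dtype=torch.float16)    # LoRA up projection
w_down = torch.randn(4, 8, dtype=torch.float16)  # LoRA down projection

# Fuse and unfuse in float32, cast back to fp16 only at the end.
delta = torch.bmm(w_up.float()[None, :], w_down.float()[None, :])[0]
fused = (w.float() + delta).to(torch.float16)
unfused = (fused.float() - delta).to(torch.float16)

print(torch.allclose(w, unfused, atol=1e-2))  # True: only half-precision rounding error remains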
@@ -1405,15 +1414,15 @@ class LoraLoaderMixin:
    def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
        for _, attn_module in text_encoder_attn_modules(text_encoder):
            if isinstance(attn_module.q_proj, PatchedLoraProjection):
                attn_module.q_proj = attn_module.q_proj.regular_linear_layer
                attn_module.k_proj = attn_module.k_proj.regular_linear_layer
                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                attn_module.out_proj = attn_module.out_proj.regular_linear_layer
                attn_module.q_proj.lora_linear_layer = None
                attn_module.k_proj.lora_linear_layer = None
                attn_module.v_proj.lora_linear_layer = None
                attn_module.out_proj.lora_linear_layer = None

        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
            if isinstance(mlp_module.fc1, PatchedLoraProjection):
                mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
                mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
                mlp_module.fc1.lora_linear_layer = None
                mlp_module.fc2.lora_linear_layer = None

    @classmethod
    def _modify_text_encoder(
@@ -1447,23 +1456,43 @@ class LoraLoaderMixin:
            else:
                current_rank = rank

            q_linear_layer = (
                attn_module.q_proj.regular_linear_layer
                if isinstance(attn_module.q_proj, PatchedLoraProjection)
                else attn_module.q_proj
            )
            attn_module.q_proj = PatchedLoraProjection(
                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
                q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())

            k_linear_layer = (
                attn_module.k_proj.regular_linear_layer
                if isinstance(attn_module.k_proj, PatchedLoraProjection)
                else attn_module.k_proj
            )
            attn_module.k_proj = PatchedLoraProjection(
                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
                k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())

            v_linear_layer = (
                attn_module.v_proj.regular_linear_layer
                if isinstance(attn_module.v_proj, PatchedLoraProjection)
                else attn_module.v_proj
            )
            attn_module.v_proj = PatchedLoraProjection(
                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
                v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())

            out_linear_layer = (
                attn_module.out_proj.regular_linear_layer
                if isinstance(attn_module.out_proj, PatchedLoraProjection)
                else attn_module.out_proj
            )
            attn_module.out_proj = PatchedLoraProjection(
                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
                out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
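The pattern introduced in this hunk, in isolation (illustrative names, not the diffusers classes): before re-patching an attribute, fetch the underlying Linear if the attribute was already replaced by a wrapper, so repeated patching never nests wrappers. A rough sketch:

from torch import nn


class Wrapper(nn.Module):
    # Hypothetical stand-in for PatchedLoraProjection.
    def __init__(self, linear):
        super().__init__()
        self.regular_linear_layer = linear

    def forward(self, x):
        return self.regular_linear_layer(x)


def patch(module, name):
    # Unwrap first if `name` was already patched, then wrap the plain Linear.
    current = getattr(module, name)
    base = current.regular_linear_layer if isinstance(current, Wrapper) else current
    setattr(module, name, Wrapper(base))


attn = nn.Module()
attn.q_proj = nn.Linear(4, 4)
patch(attn, "q_proj")
patch(attn, "q_proj")  # the second call reuses the original Linear instead of nesting wrappers
print(isinstance(attn.q_proj.regular_linear_layer, nn.Linear))  # True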
@@ -1475,13 +1504,23 @@ class LoraLoaderMixin:
                current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
                current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

            fc1_linear_layer = (
                mlp_module.fc1.regular_linear_layer
                if isinstance(mlp_module.fc1, PatchedLoraProjection)
                else mlp_module.fc1
            )
            mlp_module.fc1 = PatchedLoraProjection(
                mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
                fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
            )
            lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

            fc2_linear_layer = (
                mlp_module.fc2.regular_linear_layer
                if isinstance(mlp_module.fc2, PatchedLoraProjection)
                else mlp_module.fc2
            )
            mlp_module.fc2 = PatchedLoraProjection(
                mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
                fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
            )
            lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
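At the pipeline level, a hedged usage sketch of what this unwrap logic enables (the model and LoRA checkpoints are taken from the integration test further down, not part of the diff): re-loading LoRA weights, including after unload_lora_weights(), re-patches the original Linear layers rather than wrapping an already patched projection a second time.

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")
pipe.unload_lora_weights()

# Loading again reuses each projection's regular_linear_layer instead of
# nesting PatchedLoraProjection objects.
pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")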


@@ -168,7 +168,6 @@ class LoRACompatibleLinear(nn.Linear):
            return

        dtype, device = self.weight.data.dtype, self.weight.data.device
        logger.info(f"Fusing LoRA weights for {self.__class__}")

        w_orig = self.weight.data.float()
        w_up = self.lora_layer.up.weight.data.float()
@@ -190,14 +189,14 @@ class LoRACompatibleLinear(nn.Linear):
    def _unfuse_lora(self):
        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
            return

        logger.info(f"Unfusing LoRA weights for {self.__class__}")

        fused_weight = self.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device

        self.w_up = self.w_up.to(device=device, dtype=dtype)
        self.w_down = self.w_down.to(device, dtype=dtype)
        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
        w_up = self.w_up.to(device=device).float()
        w_down = self.w_down.to(device).float()
        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]

        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None


@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import tempfile
import time
@@ -100,6 +101,18 @@ def set_lora_weights(lora_attn_parameters, randn_weight=False):
                torch.zero_(parameter)


def state_dicts_almost_equal(sd1, sd2):
    sd1 = dict(sorted(sd1.items()))
    sd2 = dict(sorted(sd2.items()))

    models_are_equal = True
    for ten1, ten2 in zip(sd1.values(), sd2.values()):
        if (ten1 - ten2).abs().sum() > 1e-3:
            models_are_equal = False

    return models_are_equal
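A quick, hedged usage sketch of the helper above (hypothetical tensors, assuming state_dicts_almost_equal from this test file is in scope): it sums the absolute element-wise difference per tensor and flags the state dicts as unequal once any tensor differs by more than 1e-3 in total.

import torch

sd_a = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}
sd_b = {"weight": torch.full((2, 2), 1e-5), "bias": torch.zeros(2)}
sd_c = {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}

print(state_dicts_almost_equal(sd_a, sd_b))  # True: total difference is 4e-5, below the 1e-3 threshold
print(state_dicts_almost_equal(sd_a, sd_c))  # False: the weight tensors differ by 4.0 in total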
class LoraLoaderMixinTests(unittest.TestCase):
    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -674,6 +687,45 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
        sd_pipe.unload_lora_weights()

    def test_text_encoder_lora_state_dict_unchanged(self):
        pipeline_components, lora_components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)

        text_encoder_1_sd_keys = sorted(sd_pipe.text_encoder.state_dict().keys())
        text_encoder_2_sd_keys = sorted(sd_pipe.text_encoder_2.state_dict().keys())

        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        with tempfile.TemporaryDirectory() as tmpdirname:
            StableDiffusionXLPipeline.save_lora_weights(
                save_directory=tmpdirname,
                unet_lora_layers=lora_components["unet_lora_layers"],
                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
                safe_serialization=False,
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

        text_encoder_1_sd_keys_2 = sorted(sd_pipe.text_encoder.state_dict().keys())
        text_encoder_2_sd_keys_2 = sorted(sd_pipe.text_encoder_2.state_dict().keys())

        sd_pipe.unload_lora_weights()

        text_encoder_1_sd_keys_3 = sorted(sd_pipe.text_encoder.state_dict().keys())
        text_encoder_2_sd_keys_3 = sorted(sd_pipe.text_encoder_2.state_dict().keys())

        # default & unloaded LoRA weights should have identical state_dicts
        assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3
        # default & loaded LoRA weights should NOT have identical state_dicts
        assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2

        # default & unloaded LoRA weights should have identical state_dicts
        assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3
        # default & loaded LoRA weights should NOT have identical state_dicts
        assert text_encoder_2_sd_keys != text_encoder_2_sd_keys_2

    def test_load_lora_locally_safetensors(self):
        pipeline_components, lora_components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
@@ -1187,3 +1239,22 @@ class LoraIntegrationTests(unittest.TestCase):
        expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094])

        self.assertTrue(np.allclose(images, expected, atol=1e-3))

    def test_sdxl_1_0_fuse_unfuse_all(self):
        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
        text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
        text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
        unet_sd = copy.deepcopy(pipe.unet.state_dict())

        pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")
        pipe.fuse_lora()
        pipe.unload_lora_weights()
        pipe.unfuse_lora()

        new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
        new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
        new_unet_sd = copy.deepcopy(pipe.unet.state_dict())

        assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)
        assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)
        assert state_dicts_almost_equal(unet_sd, new_unet_sd)
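Finally, a hedged end-to-end sketch of the workflow this integration test exercises (the prompt, dtype, and device settings are illustrative, not from the test): fuse the LoRA into the base weights for inference, then unfuse to get the original weights back up to rounding error.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")

# Fuse the LoRA into the base weights to skip the extra LoRA matmuls at inference time.
pipe.fuse_lora()
image = pipe("a sunflower field", num_inference_steps=20).images[0]

# Unfuse restores the original base weights (up to rounding), as the test above verifies.
pipe.unfuse_lora()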