mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
removed print statements.
This commit is contained in:
@@ -1264,30 +1264,9 @@ def main(args):
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
# unet = unet.to(torch.float32)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
|
||||
|
||||
# print("*****************Initial state dict*****************")
|
||||
# for k in sorted(initial_state_dict.keys()):
|
||||
# if isinstance(initial_state_dict[k], torch.Tensor):
|
||||
# print(
|
||||
# f"{k} {list(initial_state_dict[k].shape)} mean={torch.mean(initial_state_dict[k]):.3g} std={torch.std(initial_state_dict[k]):.3g}"
|
||||
# )
|
||||
|
||||
# trained_state_dict = unet_lora_layers.state_dict()
|
||||
# print("*****************Trained state dict*****************")
|
||||
# for k in sorted(trained_state_dict.keys()):
|
||||
# if isinstance(trained_state_dict[k], torch.Tensor):
|
||||
# print(
|
||||
# f"{k} {list(trained_state_dict[k].shape)} mean={torch.mean(trained_state_dict[k]):.3g} std={torch.std(trained_state_dict[k]):.3g}"
|
||||
# )
|
||||
|
||||
# unet_attn_proc_state_dict = AttnProcsLayers(unet.attn_processors).state_dict()
|
||||
# for k in unet_attn_proc_state_dict:
|
||||
# from_unet = unet_attn_proc_state_dict[k]
|
||||
# orig = trained_state_dict[k]
|
||||
# print(f"Assertion: {torch.allclose(from_unet, orig)}")
|
||||
|
||||
if text_encoder is not None:
|
||||
text_encoder = text_encoder.to(torch.float32)
|
||||
text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
|
||||
|
||||
@@ -281,11 +281,9 @@ class UNet2DConditionLoadersMixin:
|
||||
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
||||
attn_processor_class = LoRAAttnProcessor
|
||||
|
||||
# print(f"attn_processor_class: {attn_processor_class}")
|
||||
attn_processors[key] = attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
|
||||
)
|
||||
# print(f"{attn_processors[key]} is being loaded with: {value_dict.keys()}")
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
elif is_custom_diffusion:
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
@@ -897,7 +895,6 @@ class LoraLoaderMixin:
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
# print("Inside the lora loader.")
|
||||
if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
|
||||
# Load the layers corresponding to UNet.
|
||||
unet_keys = [k for k in keys if k.startswith(self.unet_name)]
|
||||
@@ -906,7 +903,6 @@ class LoraLoaderMixin:
|
||||
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
|
||||
}
|
||||
self.unet.load_attn_procs(unet_lora_state_dict)
|
||||
# print("UNet lora loaded.")
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
|
||||
|
||||
Reference in New Issue
Block a user