mirror of https://github.com/huggingface/diffusers.git

removed print statements.

Author: Sayak Paul
Date: 2023-05-24 17:25:57 +05:30
parent 7ba7c65700
commit fe6c903373
2 changed files with 1 addition and 26 deletions


@@ -1264,30 +1264,9 @@ def main(args):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
-       # unet = unet.to(torch.float32)
+       unet = unet.to(torch.float32)
        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
-       # print("*****************Initial state dict*****************")
-       # for k in sorted(initial_state_dict.keys()):
-       # if isinstance(initial_state_dict[k], torch.Tensor):
-       # print(
-       # f"{k} {list(initial_state_dict[k].shape)} mean={torch.mean(initial_state_dict[k]):.3g} std={torch.std(initial_state_dict[k]):.3g}"
-       # )
-       # trained_state_dict = unet_lora_layers.state_dict()
-       # print("*****************Trained state dict*****************")
-       # for k in sorted(trained_state_dict.keys()):
-       # if isinstance(trained_state_dict[k], torch.Tensor):
-       # print(
-       # f"{k} {list(trained_state_dict[k].shape)} mean={torch.mean(trained_state_dict[k]):.3g} std={torch.std(trained_state_dict[k]):.3g}"
-       # )
-       # unet_attn_proc_state_dict = AttnProcsLayers(unet.attn_processors).state_dict()
-       # for k in unet_attn_proc_state_dict:
-       # from_unet = unet_attn_proc_state_dict[k]
-       # orig = trained_state_dict[k]
-       # print(f"Assertion: {torch.allclose(from_unet, orig)}")
        if text_encoder is not None:
            text_encoder = text_encoder.to(torch.float32)
            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
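For reference, the deleted lines were ad-hoc debug output: they dumped shape, mean, and std for every tensor in the initial and trained LoRA state dicts and asserted that the trained layers matched the processors re-read from `unet.attn_processors`. The sketch below captures the same check as standalone helpers; the names `summarize_state_dict` and `state_dicts_match` are hypothetical and not part of the training script or the diffusers API.

import torch


def summarize_state_dict(state_dict, title):
    # Same information as the removed debug prints: shape, mean, and std per tensor.
    print(f"***************** {title} *****************")
    for k in sorted(state_dict.keys()):
        v = state_dict[k]
        if isinstance(v, torch.Tensor):
            print(f"{k} {list(v.shape)} mean={torch.mean(v):.3g} std={torch.std(v):.3g}")


def state_dicts_match(sd_a, sd_b):
    # Counterpart of the removed torch.allclose checks: same keys, numerically close values.
    if sd_a.keys() != sd_b.keys():
        return False
    return all(torch.allclose(sd_a[k], sd_b[k]) for k in sd_a)

In the script this would have been called with `unet_lora_layers.state_dict()` and `AttnProcsLayers(unet.attn_processors).state_dict()`, the same pair the commented-out assertions compared.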


@@ -281,11 +281,9 @@ class UNet2DConditionLoadersMixin:
                    cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
                    attn_processor_class = LoRAAttnProcessor
-               # print(f"attn_processor_class: {attn_processor_class}")
                attn_processors[key] = attn_processor_class(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                )
-               # print(f"{attn_processors[key]} is being loaded with: {value_dict.keys()}")
                attn_processors[key].load_state_dict(value_dict)
        elif is_custom_diffusion:
            custom_diffusion_grouped_dict = defaultdict(dict)
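In the hunk above, the processor configuration is recovered from the LoRA weights themselves: `cross_attention_dim` is read off the second dimension of `to_k_lora.down.weight`, while `rank` and `hidden_size` come from the other weight dimensions. A minimal sketch of that shape arithmetic, assuming the usual LoRA convention (down projection shaped (rank, in_features), up projection shaped (out_features, rank)); the sizes 4, 320, and 768 are illustrative only.

import torch

# Hypothetical per-attention-block LoRA weights with made-up sizes.
rank, hidden_size, cross_attention_dim = 4, 320, 768
value_dict = {
    "to_k_lora.down.weight": torch.randn(rank, cross_attention_dim),
    "to_k_lora.up.weight": torch.randn(hidden_size, rank),
}

# The configuration can be recovered from the weight shapes alone.
inferred_cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]  # as in the hunk above
inferred_rank = value_dict["to_k_lora.down.weight"].shape[0]
inferred_hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
assert (inferred_rank, inferred_hidden_size, inferred_cross_attention_dim) == (rank, hidden_size, cross_attention_dim)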
@@ -897,7 +895,6 @@ class LoraLoaderMixin:
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
-       # print("Inside the lora loader.")
        if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
            # Load the layers corresponding to UNet.
            unet_keys = [k for k in keys if k.startswith(self.unet_name)]
@@ -906,7 +903,6 @@ class LoraLoaderMixin:
                k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
            }
            self.unet.load_attn_procs(unet_lora_state_dict)
-           # print("UNet lora loaded.")
            # Load the layers corresponding to text encoder and make necessary adjustments.
            text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
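The last two hunks sit in the code path that fans a combined LoRA state dict out to the UNet and the text encoder by key prefix: UNet keys are stripped of their prefix and handed to `self.unet.load_attn_procs`, and the remaining keys go to the text encoder. A minimal sketch of that routing, assuming the default prefixes are "unet" and "text_encoder"; `split_lora_state_dict` is a hypothetical helper, not a diffusers API.

UNET_NAME = "unet"
TEXT_ENCODER_NAME = "text_encoder"


def split_lora_state_dict(state_dict):
    # Strip the prefix and route each key to the sub-model it names;
    # the trailing `1` limits the replace to the leading prefix.
    unet_sd = {
        k.replace(f"{UNET_NAME}.", "", 1): v for k, v in state_dict.items() if k.startswith(UNET_NAME)
    }
    text_encoder_sd = {
        k.replace(f"{TEXT_ENCODER_NAME}.", "", 1): v
        for k, v in state_dict.items()
        if k.startswith(TEXT_ENCODER_NAME)
    }
    return unet_sd, text_encoder_sd

The first sub-dict corresponds to what the loader passes to `self.unet.load_attn_procs(unet_lora_state_dict)`; the second is the part applied to the text encoder with the adjustments mentioned in the comment above.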