From fe6c903373550ae928a71fcb438c1edde1ce0e30 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 24 May 2023 17:25:57 +0530
Subject: [PATCH] removed print statements.

---
 examples/dreambooth/train_dreambooth_lora.py | 23 +-------------------
 src/diffusers/loaders.py                     |  4 ----
 2 files changed, 1 insertion(+), 26 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index ac053a2f58..e268fe3ea2 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -1264,30 +1264,9 @@ def main(args):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        # unet = unet.to(torch.float32)
+        unet = unet.to(torch.float32)
         unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
 
-        # print("*****************Initial state dict*****************")
-        # for k in sorted(initial_state_dict.keys()):
-        #     if isinstance(initial_state_dict[k], torch.Tensor):
-        #         print(
-        #             f"{k} {list(initial_state_dict[k].shape)} mean={torch.mean(initial_state_dict[k]):.3g} std={torch.std(initial_state_dict[k]):.3g}"
-        #         )
-
-        # trained_state_dict = unet_lora_layers.state_dict()
-        # print("*****************Trained state dict*****************")
-        # for k in sorted(trained_state_dict.keys()):
-        #     if isinstance(trained_state_dict[k], torch.Tensor):
-        #         print(
-        #             f"{k} {list(trained_state_dict[k].shape)} mean={torch.mean(trained_state_dict[k]):.3g} std={torch.std(trained_state_dict[k]):.3g}"
-        #         )
-
-        # unet_attn_proc_state_dict = AttnProcsLayers(unet.attn_processors).state_dict()
-        # for k in unet_attn_proc_state_dict:
-        #     from_unet = unet_attn_proc_state_dict[k]
-        #     orig = trained_state_dict[k]
-        #     print(f"Assertion: {torch.allclose(from_unet, orig)}")
-
         if text_encoder is not None:
             text_encoder = text_encoder.to(torch.float32)
             text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 8d55c84334..776cdb36e5 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -281,11 +281,9 @@ class UNet2DConditionLoadersMixin:
                     cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
                     attn_processor_class = LoRAAttnProcessor
 
-                # print(f"attn_processor_class: {attn_processor_class}")
                 attn_processors[key] = attn_processor_class(
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                 )
-                # print(f"{attn_processors[key]} is being loaded with: {value_dict.keys()}")
                 attn_processors[key].load_state_dict(value_dict)
             elif is_custom_diffusion:
                 custom_diffusion_grouped_dict = defaultdict(dict)
@@ -897,7 +895,6 @@ class LoraLoaderMixin:
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
-        # print("Inside the lora loader.")
         if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
            # Load the layers corresponding to UNet.
            unet_keys = [k for k in keys if k.startswith(self.unet_name)]
@@ -906,7 +903,6 @@ class LoraLoaderMixin:
                 k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
             }
             self.unet.load_attn_procs(unet_lora_state_dict)
-            # print("UNet lora loaded.")
 
             # Load the layers corresponding to text encoder and make necessary adjustments.
             text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
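For context (not part of the patch itself): the loader code cleaned up above is the path that consumers hit when loading the LoRA weights produced by `train_dreambooth_lora.py`. Below is a minimal usage sketch; the base checkpoint ID, the output directory, and the prompt are placeholder assumptions, not values from this patch.

```python
# Sketch: loading LoRA weights saved by train_dreambooth_lora.py via
# LoraLoaderMixin.load_lora_weights (the code path edited in
# src/diffusers/loaders.py above). Model ID, paths, and prompt are placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.to("cuda")

# load_lora_weights() splits the state dict by the "unet." / "text_encoder."
# prefixes and routes the UNet part through load_attn_procs().
pipe.load_lora_weights("path/to/lora_output_dir")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("lora_sample.png")
```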