From edc154da0906ddc7ade2dfea739917266908451a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 9 Apr 2025 13:21:34 +0200 Subject: [PATCH] Update Ruff to latest Version (#10919) * update * update * update * update --- .../train_dreambooth_lora_flux_advanced.py | 8 +- .../train_dreambooth_lora_sd15_advanced.py | 17 +- .../train_dreambooth_lora_sdxl_advanced.py | 14 +- examples/amused/train_amused.py | 2 +- .../train_cogvideox_image_to_video_lora.py | 2 +- examples/cogvideo/train_cogvideox_lora.py | 2 +- .../community/adaptive_mask_inpainting.py | 2 +- examples/community/hd_painter.py | 2 +- examples/community/img2img_inpainting.py | 2 +- examples/community/llm_grounded_diffusion.py | 4 +- .../community/mod_controlnet_tile_sr_sdxl.py | 2 +- .../pipeline_flux_differential_img2img.py | 4 +- examples/community/pipeline_prompt2prompt.py | 12 +- .../community/pipeline_sdxl_style_aligned.py | 2 +- ...pipeline_stable_diffusion_upscale_ldm3d.py | 2 +- ...ne_stable_diffusion_xl_attentive_eraser.py | 6 +- ...diffusion_xl_controlnet_adapter_inpaint.py | 2 +- examples/community/scheduling_ufogen.py | 3 +- .../train_lcm_distill_lora_sd_wds.py | 2 +- .../train_lcm_distill_lora_sdxl.py | 2 +- .../train_lcm_distill_lora_sdxl_wds.py | 2 +- examples/custom_diffusion/retrieve.py | 8 +- .../train_custom_diffusion.py | 24 +- examples/dreambooth/train_dreambooth.py | 2 +- examples/dreambooth/train_dreambooth_lora.py | 2 +- .../dreambooth/train_dreambooth_lora_flux.py | 2 +- .../train_dreambooth_lora_lumina2.py | 2 +- .../dreambooth/train_dreambooth_lora_sana.py | 2 +- .../dreambooth/train_dreambooth_lora_sd3.py | 2 +- .../dreambooth/train_dreambooth_lora_sdxl.py | 4 +- .../flux-control/train_control_lora_flux.py | 8 +- examples/model_search/pipeline_easy.py | 6 +- examples/research_projects/anytext/anytext.py | 6 +- .../anytext/ocr_recog/RecSVTR.py | 6 +- .../colossalai/train_dreambooth_colossalai.py | 2 +- .../controlnet/train_controlnet_webdataset.py | 7 +- .../diffusion_dpo/train_diffusion_dpo.py | 2 +- .../diffusion_dpo/train_diffusion_dpo_sdxl.py | 2 +- .../train_diffusion_orpo_sdxl_lora.py | 4 +- .../train_diffusion_orpo_sdxl_lora_wds.py | 4 +- .../train_dreambooth_lora_flux_miniature.py | 2 +- .../geodiff_molecule_conformation.ipynb | 7393 +++++++++-------- examples/research_projects/gligen/demo.ipynb | 41 +- .../train_instruct_pix2pix_lora.py | 4 +- .../train_multi_subject_dreambooth.py | 12 +- .../textual_inversion.py | 6 +- .../textual_inversion/textual_inversion.py | 6 +- .../pipeline_prompt_diffusion.py | 3 +- .../text_to_image/train_text_to_image_xla.py | 4 +- .../dreambooth/train_dreambooth.py | 2 +- .../dreambooth/train_dreambooth_lora.py | 2 +- .../dreambooth/train_dreambooth_lora_sdxl.py | 4 +- .../train_text_to_image_lora_sdxl.py | 2 +- .../train_dreambooth_lora_sd3_miniature.py | 2 +- .../train_text_to_image_lora_sdxl.py | 2 +- .../textual_inversion/textual_inversion.py | 6 +- .../textual_inversion_sdxl.py | 12 +- examples/vqgan/test_vqgan.py | 6 +- examples/vqgan/train_vqgan.py | 12 +- pyproject.toml | 2 +- scripts/convert_amused.py | 2 +- scripts/convert_consistency_to_diffusers.py | 4 +- .../convert_dance_diffusion_to_diffusers.py | 12 +- scripts/convert_diffusers_to_original_sdxl.py | 18 +- ..._diffusers_to_original_stable_diffusion.py | 20 +- ...vert_hunyuandit_controlnet_to_diffusers.py | 6 +- scripts/convert_hunyuandit_to_diffusers.py | 9 +- scripts/convert_k_upscaler_to_diffusers.py | 10 +- scripts/convert_mochi_to_diffusers.py | 118 +- ...convert_original_audioldm2_to_diffusers.py | 2 +- .../convert_original_audioldm_to_diffusers.py | 2 +- .../convert_original_musicldm_to_diffusers.py | 2 +- scripts/convert_stable_audio.py | 18 +- scripts/convert_svd_to_diffusers.py | 12 +- scripts/convert_vq_diffusion_to_diffusers.py | 24 +- setup.py | 2 +- src/diffusers/dependency_versions_table.py | 2 +- src/diffusers/loaders/ip_adapter.py | 3 +- .../loaders/lora_conversion_utils.py | 66 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/model_loading_utils.py | 2 +- .../models/transformers/transformer_2d.py | 6 +- .../pipelines/audioldm2/pipeline_audioldm2.py | 2 +- .../controlnet/pipeline_controlnet_inpaint.py | 4 +- .../pipeline_easyanimate_inpaint.py | 2 +- .../pipeline_flux_controlnet_inpainting.py | 4 +- .../pipelines/flux/pipeline_flux_inpaint.py | 4 +- src/diffusers/pipelines/free_noise_utils.py | 6 +- .../kandinsky/pipeline_kandinsky_combined.py | 2 +- .../kandinsky/pipeline_kandinsky_inpaint.py | 2 +- .../pipelines/omnigen/processor_omnigen.py | 12 +- .../pag/pipeline_pag_controlnet_sd_inpaint.py | 6 +- .../pipelines/pag/pipeline_pag_sd_inpaint.py | 6 +- .../pag/pipeline_pag_sd_xl_inpaint.py | 6 +- .../pipeline_paint_by_example.py | 2 +- .../pipelines/pipeline_loading_utils.py | 4 +- src/diffusers/pipelines/shap_e/renderer.py | 12 +- .../stable_audio/pipeline_stable_audio.py | 2 +- .../pipeline_flax_stable_diffusion_inpaint.py | 2 +- .../pipeline_onnx_stable_diffusion_inpaint.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 6 +- ...eline_stable_diffusion_instruct_pix2pix.py | 2 +- ...ipeline_stable_diffusion_latent_upscale.py | 2 +- .../pipeline_stable_diffusion_upscale.py | 2 +- .../pipeline_stable_diffusion_3_inpaint.py | 2 +- .../pipeline_stable_diffusion_xl_inpaint.py | 6 +- .../pipelines/wan/pipeline_wan_i2v.py | 2 +- src/diffusers/quantizers/base.py | 12 +- .../scheduling_consistency_models.py | 3 +- src/diffusers/schedulers/scheduling_ddpm.py | 3 +- .../schedulers/scheduling_ddpm_parallel.py | 3 +- src/diffusers/schedulers/scheduling_lcm.py | 3 +- src/diffusers/schedulers/scheduling_tcd.py | 3 +- src/diffusers/training_utils.py | 4 +- src/diffusers/utils/deprecation_utils.py | 2 +- src/diffusers/utils/logging.py | 3 +- src/diffusers/utils/state_dict_utils.py | 2 +- src/diffusers/utils/testing_utils.py | 4 +- tests/hooks/test_hooks.py | 11 +- tests/models/test_modeling_common.py | 12 +- .../test_models_transformer_sd3.py | 12 +- .../unets/test_models_unet_2d_condition.py | 36 +- tests/others/test_image_processor.py | 30 +- tests/pipelines/amused/test_amused.py | 3 +- tests/pipelines/amused/test_amused_img2img.py | 3 +- tests/pipelines/amused/test_amused_inpaint.py | 3 +- .../aura_flow/test_pipeline_aura_flow.py | 24 +- .../blipdiffusion/test_blipdiffusion.py | 6 +- tests/pipelines/cogvideo/test_cogvideox.py | 24 +- .../cogvideo/test_cogvideox_fun_control.py | 24 +- .../cogvideo/test_cogvideox_image2video.py | 24 +- .../cogvideo/test_cogvideox_video2video.py | 24 +- .../test_controlnet_blip_diffusion.py | 6 +- .../controlnet_flux/test_controlnet_flux.py | 6 +- .../test_controlnet_flux_img2img.py | 24 +- .../test_controlnet_hunyuandit.py | 6 +- .../test_controlnet_inpaint_sd3.py | 6 +- .../controlnet_sd3/test_controlnet_sd3.py | 6 +- tests/pipelines/dit/test_dit.py | 3 +- tests/pipelines/flux/test_pipeline_flux.py | 24 +- .../flux/test_pipeline_flux_control.py | 24 +- .../test_pipeline_flux_control_inpaint.py | 24 +- .../pipelines/hunyuandit/test_hunyuan_dit.py | 31 +- tests/pipelines/kandinsky/test_kandinsky.py | 12 +- .../kandinsky/test_kandinsky_combined.py | 36 +- .../kandinsky/test_kandinsky_img2img.py | 16 +- .../kandinsky/test_kandinsky_inpaint.py | 14 +- .../pipelines/kandinsky2_2/test_kandinsky.py | 12 +- .../kandinsky2_2/test_kandinsky_combined.py | 36 +- .../kandinsky2_2/test_kandinsky_controlnet.py | 12 +- .../test_kandinsky_controlnet_img2img.py | 14 +- .../kandinsky2_2/test_kandinsky_img2img.py | 14 +- .../kandinsky2_2/test_kandinsky_inpaint.py | 14 +- tests/pipelines/kandinsky3/test_kandinsky3.py | 6 +- .../kandinsky3/test_kandinsky3_img2img.py | 6 +- tests/pipelines/pag/test_pag_animatediff.py | 6 +- tests/pipelines/pag/test_pag_controlnet_sd.py | 6 +- .../pag/test_pag_controlnet_sd_inpaint.py | 6 +- .../pipelines/pag/test_pag_controlnet_sdxl.py | 6 +- .../pag/test_pag_controlnet_sdxl_img2img.py | 6 +- tests/pipelines/pag/test_pag_hunyuan_dit.py | 24 +- tests/pipelines/pag/test_pag_kolors.py | 6 +- tests/pipelines/pag/test_pag_pixart_sigma.py | 6 +- tests/pipelines/pag/test_pag_sana.py | 6 +- tests/pipelines/pag/test_pag_sd.py | 18 +- tests/pipelines/pag/test_pag_sd3.py | 30 +- tests/pipelines/pag/test_pag_sd3_img2img.py | 18 +- tests/pipelines/pag/test_pag_sd_img2img.py | 18 +- tests/pipelines/pag/test_pag_sd_inpaint.py | 12 +- tests/pipelines/pag/test_pag_sdxl.py | 18 +- tests/pipelines/pag/test_pag_sdxl_img2img.py | 18 +- tests/pipelines/pag/test_pag_sdxl_inpaint.py | 18 +- tests/pipelines/pixart_sigma/test_pixart.py | 24 +- tests/pipelines/shap_e/test_shap_e_img2img.py | 2 +- .../test_stable_cascade_combined.py | 12 +- .../stable_diffusion/test_stable_diffusion.py | 48 +- .../test_pipeline_stable_diffusion_3.py | 24 +- .../test_stable_diffusion_xl.py | 30 +- .../test_stable_diffusion_xl_inpaint.py | 12 +- tests/pipelines/test_pipelines.py | 24 +- tests/pipelines/test_pipelines_common.py | 72 +- .../wuerstchen/test_wuerstchen_combined.py | 12 +- tests/schedulers/test_scheduler_dpm_multi.py | 6 +- tests/schedulers/test_scheduler_dpm_single.py | 6 +- .../test_scheduler_edm_dpmsolver_multistep.py | 6 +- tests/schedulers/test_scheduler_euler.py | 12 +- tests/schedulers/test_scheduler_heun.py | 6 +- .../single_file/single_file_testing_utils.py | 24 +- tests/single_file/test_lumina2_transformer.py | 6 +- .../test_model_autoencoder_dc_single_file.py | 18 +- .../test_model_controlnet_single_file.py | 6 +- ...test_model_flux_transformer_single_file.py | 6 +- .../test_model_motion_adapter_single_file.py | 24 +- .../test_model_sd_cascade_unet_single_file.py | 24 +- .../single_file/test_model_vae_single_file.py | 6 +- .../test_model_wan_autoencoder_single_file.py | 6 +- ...est_model_wan_transformer3d_single_file.py | 12 +- tests/single_file/test_sana_transformer.py | 6 +- utils/log_reports.py | 2 +- utils/update_metadata.py | 3 +- 200 files changed, 4753 insertions(+), 4694 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index f45e0a51d2..dc774d145c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -839,9 +839,9 @@ class TokenEmbeddingsHandler: idx = 0 for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." - assert all( - isinstance(tok, str) for tok in inserting_toks - ), "All elements in inserting_toks should be strings." + assert all(isinstance(tok, str) for tok in inserting_toks), ( + "All elements in inserting_toks should be strings." + ) self.inserting_toks = inserting_toks special_tokens_dict = {"additional_special_tokens": self.inserting_toks} @@ -1605,7 +1605,7 @@ def main(args): lora_state_dict = FluxPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 8cd1d777c0..95ba53391c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -200,7 +200,8 @@ Special VAE used for training: {vae_path}. "diffusers", "diffusers-training", lora, - "template:sd-lora" "stable-diffusion", + "template:sd-lora", + "stable-diffusion", "stable-diffusion-diffusers", ] model_card = populate_model_card(model_card, tags=tags) @@ -724,9 +725,9 @@ class TokenEmbeddingsHandler: idx = 0 for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." - assert all( - isinstance(tok, str) for tok in inserting_toks - ), "All elements in inserting_toks should be strings." + assert all(isinstance(tok, str) for tok in inserting_toks), ( + "All elements in inserting_toks should be strings." + ) self.inserting_toks = inserting_toks special_tokens_dict = {"additional_special_tokens": self.inserting_toks} @@ -746,9 +747,9 @@ class TokenEmbeddingsHandler: .to(dtype=self.dtype) * std_token_embedding ) - self.embeddings_settings[ - f"original_embeddings_{idx}" - ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + self.embeddings_settings[f"original_embeddings_{idx}"] = ( + text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + ) self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding inu = torch.ones((len(tokenizer),), dtype=torch.bool) @@ -1322,7 +1323,7 @@ def main(args): lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index f8253715e6..236dc20d62 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -890,9 +890,9 @@ class TokenEmbeddingsHandler: idx = 0 for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." - assert all( - isinstance(tok, str) for tok in inserting_toks - ), "All elements in inserting_toks should be strings." + assert all(isinstance(tok, str) for tok in inserting_toks), ( + "All elements in inserting_toks should be strings." + ) self.inserting_toks = inserting_toks special_tokens_dict = {"additional_special_tokens": self.inserting_toks} @@ -912,9 +912,9 @@ class TokenEmbeddingsHandler: .to(dtype=self.dtype) * std_token_embedding ) - self.embeddings_settings[ - f"original_embeddings_{idx}" - ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + self.embeddings_settings[f"original_embeddings_{idx}"] = ( + text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + ) self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding inu = torch.ones((len(tokenizer),), dtype=torch.bool) @@ -1647,7 +1647,7 @@ def main(args): lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py index df44a0a63a..d71d9ccbb8 100644 --- a/examples/amused/train_amused.py +++ b/examples/amused/train_amused.py @@ -720,7 +720,7 @@ def main(args): # Train! logger.info("***** Running training *****") logger.info(f" Num training steps = {args.max_train_steps}") - logger.info(f" Instantaneous batch size per device = { args.train_batch_size}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index eed8305f4f..35d4d15622 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -1138,7 +1138,7 @@ def main(args): lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 74ea98cbac..bf09ff02ae 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -1159,7 +1159,7 @@ def main(args): lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py index df73695648..81f9527b47 100644 --- a/examples/community/adaptive_mask_inpainting.py +++ b/examples/community/adaptive_mask_inpainting.py @@ -1103,7 +1103,7 @@ class AdaptiveMaskInpaintPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `default_mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/examples/community/hd_painter.py b/examples/community/hd_painter.py index 91ebe07610..9d7b95b62c 100644 --- a/examples/community/hd_painter.py +++ b/examples/community/hd_painter.py @@ -686,7 +686,7 @@ class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline): f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py index 292c9aa2bc..001e4cc5b2 100644 --- a/examples/community/img2img_inpainting.py +++ b/examples/community/img2img_inpainting.py @@ -362,7 +362,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline): f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py index 129793dae6..814694f1e3 100644 --- a/examples/community/llm_grounded_diffusion.py +++ b/examples/community/llm_grounded_diffusion.py @@ -1120,7 +1120,7 @@ class LLMGroundedDiffusionPipeline( if verbose: logger.info( - f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}" + f"time index {index}, loss: {loss.item() / loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}" ) try: @@ -1184,7 +1184,7 @@ class LLMGroundedDiffusionPipeline( if verbose: logger.info( - f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}" + f"time index {index}, loss: {loss.item() / loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}" ) finally: diff --git a/examples/community/mod_controlnet_tile_sr_sdxl.py b/examples/community/mod_controlnet_tile_sr_sdxl.py index 80bed2365d..3db2645a78 100644 --- a/examples/community/mod_controlnet_tile_sr_sdxl.py +++ b/examples/community/mod_controlnet_tile_sr_sdxl.py @@ -701,7 +701,7 @@ class StableDiffusionXLControlNetTileSRPipeline( raise ValueError("`max_tile_size` cannot be None.") elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280): raise ValueError( - f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type" f" {type(max_tile_size)}." + f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type {type(max_tile_size)}." ) if tile_gaussian_sigma is None: raise ValueError("`tile_gaussian_sigma` cannot be None.") diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py index 9d6be763a0..5dc321ea98 100644 --- a/examples/community/pipeline_flux_differential_img2img.py +++ b/examples/community/pipeline_flux_differential_img2img.py @@ -488,7 +488,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -496,7 +496,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py index 736f00799e..b9985542cc 100644 --- a/examples/community/pipeline_prompt2prompt.py +++ b/examples/community/pipeline_prompt2prompt.py @@ -907,12 +907,12 @@ def create_controller( # reweight if edit_type == "reweight": - assert ( - equalizer_words is not None and equalizer_strengths is not None - ), "To use reweight edit, please specify equalizer_words and equalizer_strengths." - assert len(equalizer_words) == len( - equalizer_strengths - ), "equalizer_words and equalizer_strengths must be of same length." + assert equalizer_words is not None and equalizer_strengths is not None, ( + "To use reweight edit, please specify equalizer_words and equalizer_strengths." + ) + assert len(equalizer_words) == len(equalizer_strengths), ( + "equalizer_words and equalizer_strengths must be of same length." + ) equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) return AttentionReweight( prompts, diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py index 9377caf7ba..6aebb6c18d 100644 --- a/examples/community/pipeline_sdxl_style_aligned.py +++ b/examples/community/pipeline_sdxl_style_aligned.py @@ -1738,7 +1738,7 @@ class StyleAlignedSDXLPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py index 8a709ab467..6c63f53e81 100644 --- a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py +++ b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py @@ -689,7 +689,7 @@ class StableDiffusionUpscaleLDM3DPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" + f" = {num_channels_latents + num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py index 1269a69f0d..8459553f4e 100644 --- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py +++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py @@ -1028,7 +1028,7 @@ class StableDiffusionXL_AE_Pipeline( if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -1036,7 +1036,7 @@ class StableDiffusionXL_AE_Pipeline( f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( @@ -2050,7 +2050,7 @@ class StableDiffusionXL_AE_Pipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py index 8480117866..6a0ed3523d 100644 --- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py @@ -1578,7 +1578,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/examples/community/scheduling_ufogen.py b/examples/community/scheduling_ufogen.py index 4b1b92ff18..0b832394cf 100644 --- a/examples/community/scheduling_ufogen.py +++ b/examples/community/scheduling_ufogen.py @@ -288,8 +288,7 @@ class UFOGenScheduler(SchedulerMixin, ConfigMixin): if timesteps[0] >= self.config.num_train_timesteps: raise ValueError( - f"`timesteps` must start before `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps}." + f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}." ) timesteps = np.array(timesteps, dtype=np.int64) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 2045e78093..28fc7c73e6 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -89,7 +89,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter # Set alpha parameter if "lora_down" in kohya_key: - alpha_key = f'{kohya_key.split(".")[0]}.alpha' + alpha_key = f"{kohya_key.split('.')[0]}.alpha" kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) return kohya_ss_state_dict diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 38fe94ed3f..61d883fdfb 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -901,7 +901,7 @@ def main(args): unet_ = accelerator.unwrap_model(unet) lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir) unet_state_dict = { - f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.") } unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index fdb789c216..4324f81b96 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter # Set alpha parameter if "lora_down" in kohya_key: - alpha_key = f'{kohya_key.split(".")[0]}.alpha' + alpha_key = f"{kohya_key.split('.')[0]}.alpha" kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) return kohya_ss_state_dict diff --git a/examples/custom_diffusion/retrieve.py b/examples/custom_diffusion/retrieve.py index a28fe344d9..27f4b4e0dc 100644 --- a/examples/custom_diffusion/retrieve.py +++ b/examples/custom_diffusion/retrieve.py @@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images): total = 0 pbar = tqdm(desc="downloading real regularization images", total=num_class_images) - with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open( - f"{class_data_dir}/images.txt", "w" - ) as f3: + with ( + open(f"{class_data_dir}/caption.txt", "w") as f1, + open(f"{class_data_dir}/urls.txt", "w") as f2, + open(f"{class_data_dir}/images.txt", "w") as f3, + ): while total < num_class_images: images = class_images[count] count += 1 diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index ea1449f9f3..fa2959cf41 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -731,18 +731,18 @@ def main(args): if not class_images_dir.exists(): class_images_dir.mkdir(parents=True, exist_ok=True) if args.real_prior: - assert ( - class_images_dir / "images" - ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" - assert ( - len(list((class_images_dir / "images").iterdir())) == args.num_class_images - ), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" - assert ( - class_images_dir / "caption.txt" - ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" - assert ( - class_images_dir / "images.txt" - ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" + assert (class_images_dir / "images").exists(), ( + f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}' + ) + assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, ( + f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}' + ) + assert (class_images_dir / "caption.txt").exists(), ( + f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}' + ) + assert (class_images_dir / "images.txt").exists(), ( + f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}' + ) concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt") concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt") args.concepts_list[i] = concept diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index b863f56412..43e680610e 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1014,7 +1014,7 @@ def main(args): if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: raise ValueError( - f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" + f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 9584e7762d..7f8d06f34a 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -982,7 +982,7 @@ def main(args): lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index debdafd04b..febf7e51c6 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1294,7 +1294,7 @@ def main(args): lora_state_dict = FluxPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py index a8bf4e1cdc..d2cedc2486 100644 --- a/examples/dreambooth/train_dreambooth_lora_lumina2.py +++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py @@ -1053,7 +1053,7 @@ def main(args): lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 674cb0d1ad..899b1ff679 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -1064,7 +1064,7 @@ def main(args): lora_state_dict = SanaPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 4a08daaf61..63cef5d176 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1355,7 +1355,7 @@ def main(args): lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 735d48b834..37241b8f9e 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -118,7 +118,7 @@ def save_model_card( ) model_description = f""" -# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} +# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id} @@ -1286,7 +1286,7 @@ def main(args): lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index 56c5f2a89a..2a9bfd949c 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f torch_dtype=weight_dtype, ) pipeline.load_lora_weights(args.output_dir) - assert ( - pipeline.transformer.config.in_channels == initial_channels * 2 - ), f"{pipeline.transformer.config.in_channels=}" + assert pipeline.transformer.config.in_channels == initial_channels * 2, ( + f"{pipeline.transformer.config.in_channels=}" + ) pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) @@ -954,7 +954,7 @@ def main(args): lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir) transformer_lora_state_dict = { - f'{k.replace("transformer.", "")}': v + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k } diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py index a8add83110..b82e98fb71 100644 --- a/examples/model_search/pipeline_easy.py +++ b/examples/model_search/pipeline_easy.py @@ -1081,9 +1081,9 @@ class AutoConfig: f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}" ) - pretrained_model_name_or_paths[ - pretrained_model_name_or_paths.index(search_word) - ] = textual_inversion_path.model_path + pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = ( + textual_inversion_path.model_path + ) self.load_textual_inversion( pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py index 5c30b24efe..2e96014c41 100644 --- a/examples/research_projects/anytext/anytext.py +++ b/examples/research_projects/anytext/anytext.py @@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string): return_tensors="pt", ) tokens = batch_encoding["input_ids"] - assert ( - torch.count_nonzero(tokens - 49407) == 2 - ), f"String '{string}' maps to more than a single token. Please use another string" + assert torch.count_nonzero(tokens - 49407) == 2, ( + f"String '{string}' maps to more than a single token. Please use another string" + ) return tokens[0, 1] diff --git a/examples/research_projects/anytext/ocr_recog/RecSVTR.py b/examples/research_projects/anytext/ocr_recog/RecSVTR.py index 590a96995b..3dc813b84a 100644 --- a/examples/research_projects/anytext/ocr_recog/RecSVTR.py +++ b/examples/research_projects/anytext/ocr_recog/RecSVTR.py @@ -312,9 +312,9 @@ class PatchEmbed(nn.Module): def forward(self, x): B, C, H, W = x.shape - assert ( - H == self.img_size[0] and W == self.img_size[1] - ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert H == self.img_size[0] and W == self.img_size[1], ( + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + ) x = self.proj(x).flatten(2).permute(0, 2, 1) return x diff --git a/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/examples/research_projects/colossalai/train_dreambooth_colossalai.py index 10c8e095a6..4e541b8d3a 100644 --- a/examples/research_projects/colossalai/train_dreambooth_colossalai.py +++ b/examples/research_projects/colossalai/train_dreambooth_colossalai.py @@ -619,7 +619,7 @@ def main(args): optimizer.step() lr_scheduler.step() - logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0]) + logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0]) # Checks if the accelerator has performed an optimization step behind the scenes progress_bar.update(1) global_step += 1 diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py index 829b003115..9744bc7be2 100644 --- a/examples/research_projects/controlnet/train_controlnet_webdataset.py +++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py @@ -803,21 +803,20 @@ def parse_args(input_args=None): "--control_type", type=str, default="canny", - help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."), + help=("The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."), ) parser.add_argument( "--transformer_layers_per_block", type=str, default=None, - help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."), + help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."), ) parser.add_argument( "--old_style_controlnet", action="store_true", default=False, help=( - "Use the old style controlnet, which is a single transformer layer with" - " a single head. Defaults to False." + "Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False." ), ) diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py index ab88d49677..0b9c248ed0 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py @@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False): - logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") + logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.") # create pipeline pipeline = DiffusionPipeline.from_pretrained( diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py index 0297a06f5b..f0afa12e9c 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py @@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path( def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): - logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") + logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.") if is_final_validation: if args.mixed_precision == "fp16": diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py index ed245e9cef..12eb67d4a7 100644 --- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py +++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py @@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path( def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): - logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") + logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.") if is_final_validation: if args.mixed_precision == "fp16": @@ -683,7 +683,7 @@ def main(args): lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py index 66a7a36529..a5d89f77d6 100644 --- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py +++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py @@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path( def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): - logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") + logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.") if is_final_validation: if args.mixed_precision == "fp16": @@ -790,7 +790,7 @@ def main(args): lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py index ccaf3164a0..cc535bbaaa 100644 --- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py +++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py @@ -783,7 +783,7 @@ def main(args): lora_state_dict = FluxPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb index bde093802a..aa5951723a 100644 --- a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb +++ b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb @@ -1,3652 +1,3745 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "F88mignPnalS" - }, - "source": [ - "# Introduction\n", - "\n", - "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", - "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", - "\n", - "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n", - "\n", - "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", - "\n", - "> Colab made by [natolambert](https://twitter.com/natolambert).\n", - "\n", - "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7cnwXMocnuzB" - }, - "source": [ - "## Installations\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Install Conda" - ], - "metadata": { - "id": "ff9SxWnaNId9" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1g_6zOabItDk" - }, - "source": [ - "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "K0ofXobG5Y-X", - "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "nvcc: NVIDIA (R) Cuda compiler driver\n", - "Copyright (c) 2005-2021 NVIDIA Corporation\n", - "Built on Sun_Feb_14_21:12:58_PST_2021\n", - "Cuda compilation tools, release 11.2, V11.2.152\n", - "Build cuda_11.2.r11.2/compiler.29618528_0\n" - ] - } - ], - "source": [ - "!nvcc --version" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VfthW90vI0nw" - }, - "source": [ - "Install Conda for some more complex dependencies for geometric networks." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2WNFzSnbiE0k", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install -q condacolab" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NUsbWYCUI7Km" - }, - "source": [ - "Setup Conda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FZelreINdmd0", - "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✨🍰✨ Everything looks OK!\n" - ] - } - ], - "source": [ - "import condacolab\n", - "condacolab.install()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JzDHaPU7I9Sn" - }, - "source": [ - "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMxRjHhL7w8V", - "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", - "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - cudatoolkit=11.1\n", - " - pytorch\n", - " - torchaudio\n", - " - torchvision\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 960 KB\n", - "\n", - "The following packages will be UPDATED:\n", - "\n", - " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", - "Preparing transaction: / \b\bdone\n", - "Verifying transaction: \\ \b\bdone\n", - "Executing transaction: / \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", - "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" - ] - }, - { - "cell_type": "markdown", - "source": [ - "Need to remove a pathspec for colab that specifies the incorrect cuda version." - ], - "metadata": { - "id": "QDS6FPZ0Tu5b" - } - }, - { - "cell_type": "code", - "source": [ - "!rm /usr/local/conda-meta/pinned" - ], - "metadata": { - "id": "dq1lxR10TtrR", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z1L3DdZOJB30" - }, - "source": [ - "Install torch geometric (used in the model later)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D5ukfCOWfjzK", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - pytorch-geometric=1.7.2\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " decorator-4.4.2 | py_0 11 KB conda-forge\n", - " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", - " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", - " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", - " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", - " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", - " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", - " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", - " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", - " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", - " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", - " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", - " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", - " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", - " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", - " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", - " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", - " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", - " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", - " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 55.9 MB\n", - "\n", - "The following NEW packages will be INSTALLED:\n", - "\n", - " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", - " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", - " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", - " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", - " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", - " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", - " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", - " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", - " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", - " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", - " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", - " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", - " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", - " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", - " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", - " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", - " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", - " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", - " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", - "\n", - "The following packages will be DOWNGRADED:\n", - "\n", - " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", - "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", - "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", - "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", - "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", - "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", - "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", - "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", - "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", - "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", - "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", - "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", - "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", - "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", - "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", - "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", - "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", - "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", - "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", - "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", - "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", - "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install -c rusty1s pytorch-geometric=1.7.2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ppxv6Mdkalbc" - }, - "source": [ - "### Install Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mgQA_XN-XGY2", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "/content\n", - "Cloning into 'diffusers'...\n", - "remote: Enumerating objects: 9298, done.\u001b[K\n", - "remote: Counting objects: 100% (40/40), done.\u001b[K\n", - "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", - "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", - "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", - "Resolving deltas: 100% (6168/6168), done.\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "%cd /content\n", - "\n", - "# install latest HF diffusers (will update to the release once added)\n", - "!git clone https://github.com/huggingface/diffusers.git\n", - "!pip install -q /content/diffusers\n", - "\n", - "# dependencies for diffusers\n", - "!pip install -q datasets transformers" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LZO6AJKuJKO8" - }, - "source": [ - "Check that torch is installed correctly and utilizing the GPU in the colab" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gZt7BNi1e1PA", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 53 - }, - "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "True\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "'1.8.2'" - ], - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - } - }, - "metadata": {}, - "execution_count": 8 - } - ], - "source": [ - "import torch\n", - "print(torch.cuda.is_available())\n", - "torch.__version__" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KLE7CqlfJNUO" - }, - "source": [ - "### Install Chemistry-specific Dependencies\n", - "\n", - "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0CPv_NvehRz3", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting rdkit\n", - " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", - "Installing collected packages: rdkit\n", - "Successfully installed rdkit-2022.3.5\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install rdkit" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "88GaDbDPxJ5I" - }, - "source": [ - "### Get viewer from nglview\n", - "\n", - "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", - "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", - "The rdmol in this object is a source of ground truth for the generated molecules.\n", - "\n", - "You will use one rendering function from nglviewer later!\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jcl8GCS2mz6t", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting nglview\n", - " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", - "Collecting jupyterlab-widgets\n", - " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipywidgets>=7\n", - " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting widgetsnbextension~=4.0\n", - " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipython>=6.1.0\n", - " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipykernel>=4.5.1\n", - " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting traitlets>=4.3.1\n", - " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", - "Collecting pyzmq>=17\n", - " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting matplotlib-inline>=0.1\n", - " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", - "Collecting tornado>=6.1\n", - " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting nest-asyncio\n", - " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", - "Collecting debugpy>=1.0\n", - " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting psutil\n", - " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting jupyter-client>=6.1.12\n", - " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pickleshare\n", - " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", - "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", - "Collecting backcall\n", - " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", - "Collecting pexpect>4.3\n", - " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pygments\n", - " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting jedi>=0.16\n", - " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", - " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", - "Collecting parso<0.9.0,>=0.8.0\n", - " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", - "Collecting entrypoints\n", - " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", - "Collecting jupyter-core>=4.9.2\n", - " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ptyprocess>=0.5\n", - " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", - "Collecting wcwidth\n", - " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", - "Building wheels for collected packages: nglview\n", - " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", - " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", - "Successfully built nglview\n", - "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", - "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "pexpect", - "pickleshare", - "wcwidth" - ] - } - } - }, - "metadata": {} - } - ], - "source": [ - "!pip install nglview" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Create a diffusion model" - ], - "metadata": { - "id": "8t8_e_uVLdKB" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Model class(es)" - ], - "metadata": { - "id": "G0rMncVtNSqU" - } - }, - { - "cell_type": "markdown", - "source": [ - "Imports" - ], - "metadata": { - "id": "L5FEXz5oXkzt" - } - }, - { - "cell_type": "code", - "source": [ - "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", - "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", - "from dataclasses import dataclass\n", - "from typing import Callable, Tuple, Union\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from torch import Tensor, nn\n", - "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", - "\n", - "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", - "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", - "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", - "from torch_scatter import scatter_add\n", - "from torch_sparse import SparseTensor, coalesce\n", - "\n", - "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", - "from diffusers.modeling_utils import ModelMixin\n", - "from diffusers.utils import BaseOutput\n" - ], - "metadata": { - "id": "-3-P4w5sXkRU" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Helper classes" - ], - "metadata": { - "id": "EzJQXPN_XrMX" - } - }, - { - "cell_type": "code", - "source": [ - "@dataclass\n", - "class MoleculeGNNOutput(BaseOutput):\n", - " \"\"\"\n", - " Args:\n", - " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", - " Hidden states output. Output of last layer of model.\n", - " \"\"\"\n", - "\n", - " sample: torch.Tensor\n", - "\n", - "\n", - "class MultiLayerPerceptron(nn.Module):\n", - " \"\"\"\n", - " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", - " Args:\n", - " input_dim (int): input dimension\n", - " hidden_dim (list of int): hidden dimensions\n", - " activation (str or function, optional): activation function\n", - " dropout (float, optional): dropout rate\n", - " \"\"\"\n", - "\n", - " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", - " super(MultiLayerPerceptron, self).__init__()\n", - "\n", - " self.dims = [input_dim] + hidden_dims\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", - " self.activation = None\n", - " if dropout > 0:\n", - " self.dropout = nn.Dropout(dropout)\n", - " else:\n", - " self.dropout = None\n", - "\n", - " self.layers = nn.ModuleList()\n", - " for i in range(len(self.dims) - 1):\n", - " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\"\"\"\n", - " for i, layer in enumerate(self.layers):\n", - " x = layer(x)\n", - " if i < len(self.layers) - 1:\n", - " if self.activation:\n", - " x = self.activation(x)\n", - " if self.dropout:\n", - " x = self.dropout(x)\n", - " return x\n", - "\n", - "\n", - "class ShiftedSoftplus(torch.nn.Module):\n", - " def __init__(self):\n", - " super(ShiftedSoftplus, self).__init__()\n", - " self.shift = torch.log(torch.tensor(2.0)).item()\n", - "\n", - " def forward(self, x):\n", - " return F.softplus(x) - self.shift\n", - "\n", - "\n", - "class CFConv(MessagePassing):\n", - " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", - " super(CFConv, self).__init__(aggr=\"add\")\n", - " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", - " self.lin2 = Linear(num_filters, out_channels)\n", - " self.nn = mlp\n", - " self.cutoff = cutoff\n", - " self.smooth = smooth\n", - "\n", - " self.reset_parameters()\n", - "\n", - " def reset_parameters(self):\n", - " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", - " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", - " self.lin2.bias.data.fill_(0)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " if self.smooth:\n", - " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", - " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", - " else:\n", - " C = (edge_length <= self.cutoff).float()\n", - " W = self.nn(edge_attr) * C.view(-1, 1)\n", - "\n", - " x = self.lin1(x)\n", - " x = self.propagate(edge_index, x=x, W=W)\n", - " x = self.lin2(x)\n", - " return x\n", - "\n", - " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", - " return x_j * W\n", - "\n", - "\n", - "class InteractionBlock(torch.nn.Module):\n", - " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", - " super(InteractionBlock, self).__init__()\n", - " mlp = Sequential(\n", - " Linear(num_gaussians, num_filters),\n", - " ShiftedSoftplus(),\n", - " Linear(num_filters, num_filters),\n", - " )\n", - " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", - " self.act = ShiftedSoftplus()\n", - " self.lin = Linear(hidden_channels, hidden_channels)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " x = self.conv(x, edge_index, edge_length, edge_attr)\n", - " x = self.act(x)\n", - " x = self.lin(x)\n", - " return x\n", - "\n", - "\n", - "class SchNetEncoder(Module):\n", - " def __init__(\n", - " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.hidden_channels = hidden_channels\n", - " self.num_filters = num_filters\n", - " self.num_interactions = num_interactions\n", - " self.cutoff = cutoff\n", - "\n", - " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", - "\n", - " self.interactions = ModuleList()\n", - " for _ in range(num_interactions):\n", - " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", - " self.interactions.append(block)\n", - "\n", - " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", - " if embed_node:\n", - " assert z.dim() == 1 and z.dtype == torch.long\n", - " h = self.embedding(z)\n", - " else:\n", - " h = z\n", - " for interaction in self.interactions:\n", - " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", - "\n", - " return h\n", - "\n", - "\n", - "class GINEConv(MessagePassing):\n", - " \"\"\"\n", - " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", - " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", - " \"\"\"\n", - "\n", - " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", - " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", - " self.nn = mlp\n", - " self.initial_eps = eps\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " if train_eps:\n", - " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", - " else:\n", - " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", - "\n", - " def forward(\n", - " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", - " ) -> torch.Tensor:\n", - " \"\"\"\"\"\"\n", - " if isinstance(x, torch.Tensor):\n", - " x: OptPairTensor = (x, x)\n", - "\n", - " # Node and edge feature dimensionalites need to match.\n", - " if isinstance(edge_index, torch.Tensor):\n", - " assert edge_attr is not None\n", - " assert x[0].size(-1) == edge_attr.size(-1)\n", - " elif isinstance(edge_index, SparseTensor):\n", - " assert x[0].size(-1) == edge_index.size(-1)\n", - "\n", - " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", - " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", - "\n", - " x_r = x[1]\n", - " if x_r is not None:\n", - " out += (1 + self.eps) * x_r\n", - "\n", - " return self.nn(out)\n", - "\n", - " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", - " if self.activation:\n", - " return self.activation(x_j + edge_attr)\n", - " else:\n", - " return x_j + edge_attr\n", - "\n", - " def __repr__(self):\n", - " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", - "\n", - "\n", - "class GINEncoder(torch.nn.Module):\n", - " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", - " super().__init__()\n", - "\n", - " self.hidden_dim = hidden_dim\n", - " self.num_convs = num_convs\n", - " self.short_cut = short_cut\n", - " self.concat_hidden = concat_hidden\n", - " self.node_emb = nn.Embedding(100, hidden_dim)\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " self.convs = nn.ModuleList()\n", - " for i in range(self.num_convs):\n", - " self.convs.append(\n", - " GINEConv(\n", - " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", - " activation=activation,\n", - " )\n", - " )\n", - "\n", - " def forward(self, z, edge_index, edge_attr):\n", - " \"\"\"\n", - " Input:\n", - " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", - " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", - " Output:\n", - " node_feature: graph feature\n", - " \"\"\"\n", - "\n", - " node_attr = self.node_emb(z) # (num_node, hidden)\n", - "\n", - " hiddens = []\n", - " conv_input = node_attr # (num_node, hidden)\n", - "\n", - " for conv_idx, conv in enumerate(self.convs):\n", - " hidden = conv(conv_input, edge_index, edge_attr)\n", - " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", - " hidden = self.activation(hidden)\n", - " assert hidden.shape == conv_input.shape\n", - " if self.short_cut and hidden.shape == conv_input.shape:\n", - " hidden += conv_input\n", - "\n", - " hiddens.append(hidden)\n", - " conv_input = hidden\n", - "\n", - " if self.concat_hidden:\n", - " node_feature = torch.cat(hiddens, dim=-1)\n", - " else:\n", - " node_feature = hiddens[-1]\n", - "\n", - " return node_feature\n", - "\n", - "\n", - "class MLPEdgeEncoder(Module):\n", - " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", - " super().__init__()\n", - " self.hidden_dim = hidden_dim\n", - " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", - " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", - "\n", - " @property\n", - " def out_channels(self):\n", - " return self.hidden_dim\n", - "\n", - " def forward(self, edge_length, edge_type):\n", - " \"\"\"\n", - " Input:\n", - " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", - " Returns:\n", - " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", - " \"\"\"\n", - " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", - " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", - " return d_emb * edge_attr # (num_edge, hidden)\n", - "\n", - "\n", - "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", - " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", - " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", - " return h_pair\n", - "\n", - "\n", - "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", - " \"\"\"\n", - " Args:\n", - " num_nodes: Number of atoms.\n", - " edge_index: Bond indices of the original graph.\n", - " edge_type: Bond types of the original graph.\n", - " order: Extension order.\n", - " Returns:\n", - " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n", - " \"\"\"\n", - "\n", - " def binarize(x):\n", - " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", - "\n", - " def get_higher_order_adj_matrix(adj, order):\n", - " \"\"\"\n", - " Args:\n", - " adj: (N, N)\n", - " type_mat: (N, N)\n", - " Returns:\n", - " Following attributes will be updated:\n", - " - edge_index\n", - " - edge_type\n", - " Following attributes will be added to the data object:\n", - " - bond_edge_index: Original edge_index.\n", - " \"\"\"\n", - " adj_mats = [\n", - " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", - " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", - " ]\n", - "\n", - " for i in range(2, order + 1):\n", - " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", - " order_mat = torch.zeros_like(adj)\n", - "\n", - " for i in range(1, order + 1):\n", - " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", - "\n", - " return order_mat\n", - "\n", - " num_types = 22\n", - " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", - " # from rdkit.Chem.rdchem import BondType as BT\n", - " N = num_nodes\n", - " adj = to_dense_adj(edge_index).squeeze(0)\n", - " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", - "\n", - " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", - " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", - " assert (type_mat * type_highorder == 0).all()\n", - " type_new = type_mat + type_highorder\n", - "\n", - " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", - " _, edge_order = dense_to_sparse(adj_order)\n", - "\n", - " # data.bond_edge_index = data.edge_index # Save original edges\n", - " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", - " assert edge_type.dim() == 1\n", - " N = pos.size(0)\n", - "\n", - " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", - "\n", - " if is_sidechain is None:\n", - " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", - " else:\n", - " # fetch sidechain and its batch index\n", - " is_sidechain = is_sidechain.bool()\n", - " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", - " sidechain_pos = pos[is_sidechain]\n", - " sidechain_index = dummy_index[is_sidechain]\n", - " sidechain_batch = batch[is_sidechain]\n", - "\n", - " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", - " r_edge_index_x = assign_index[1]\n", - " r_edge_index_y = assign_index[0]\n", - " r_edge_index_y = sidechain_index[r_edge_index_y]\n", - "\n", - " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", - " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", - " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", - " # delete self loop\n", - " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", - "\n", - " rgraph_adj = torch.sparse.LongTensor(\n", - " rgraph_edge_index,\n", - " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", - " torch.Size([N, N]),\n", - " )\n", - "\n", - " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", - "\n", - " new_edge_index = composed_adj.indices()\n", - " new_edge_type = composed_adj.values().long()\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def extend_graph_order_radius(\n", - " num_nodes,\n", - " pos,\n", - " edge_index,\n", - " edge_type,\n", - " batch,\n", - " order=3,\n", - " cutoff=10.0,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - "):\n", - " if extend_order:\n", - " edge_index, edge_type = _extend_graph_order(\n", - " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", - " )\n", - "\n", - " if extend_radius:\n", - " edge_index, edge_type = _extend_to_radius_graph(\n", - " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", - " )\n", - "\n", - " return edge_index, edge_type\n", - "\n", - "\n", - "def get_distance(pos, edge_index):\n", - " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", - "\n", - "\n", - "def graph_field_network(score_d, pos, edge_index, edge_length):\n", - " \"\"\"\n", - " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", - " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", - " \"\"\"\n", - " N = pos.size(0)\n", - " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", - " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", - " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", - " ) # (N, 3)\n", - " return score_pos\n", - "\n", - "\n", - "def clip_norm(vec, limit, p=2):\n", - " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", - " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", - " return vec * denom\n", - "\n", - "\n", - "def is_local_edge(edge_type):\n", - " return edge_type > 0\n" - ], - "metadata": { - "id": "oR1Y56QiLY90" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Main model class!" - ], - "metadata": { - "id": "QWrHJFcYXyUB" - } - }, - { - "cell_type": "code", - "source": [ - "class MoleculeGNN(ModelMixin, ConfigMixin):\n", - " @register_to_config\n", - " def __init__(\n", - " self,\n", - " hidden_dim=128,\n", - " num_convs=6,\n", - " num_convs_local=4,\n", - " cutoff=10.0,\n", - " mlp_act=\"relu\",\n", - " edge_order=3,\n", - " edge_encoder=\"mlp\",\n", - " smooth_conv=True,\n", - " ):\n", - " super().__init__()\n", - " self.cutoff = cutoff\n", - " self.edge_encoder = edge_encoder\n", - " self.edge_order = edge_order\n", - "\n", - " \"\"\"\n", - " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", - " in SchNetEncoder\n", - " \"\"\"\n", - " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - "\n", - " \"\"\"\n", - " The graph neural network that extracts node-wise features.\n", - " \"\"\"\n", - " self.encoder_global = SchNetEncoder(\n", - " hidden_channels=hidden_dim,\n", - " num_filters=hidden_dim,\n", - " num_interactions=num_convs,\n", - " edge_channels=self.edge_encoder_global.out_channels,\n", - " cutoff=cutoff,\n", - " smooth=smooth_conv,\n", - " )\n", - " self.encoder_local = GINEncoder(\n", - " hidden_dim=hidden_dim,\n", - " num_convs=num_convs_local,\n", - " )\n", - "\n", - " \"\"\"\n", - " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", - " gradients w.r.t. edge_length (out_dim = 1).\n", - " \"\"\"\n", - " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " \"\"\"\n", - " Incorporate parameters together\n", - " \"\"\"\n", - " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", - " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", - "\n", - " def _forward(\n", - " self,\n", - " atom_type,\n", - " pos,\n", - " bond_index,\n", - " bond_type,\n", - " batch,\n", - " time_step, # NOTE, model trained without timestep performed best\n", - " edge_index=None,\n", - " edge_type=None,\n", - " edge_length=None,\n", - " return_edges=False,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - " ):\n", - " \"\"\"\n", - " Args:\n", - " atom_type: Types of atoms, (N, ).\n", - " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", - " bond_type: Bond types, (E, ).\n", - " batch: Node index to graph index, (N, ).\n", - " \"\"\"\n", - " N = atom_type.size(0)\n", - " if edge_index is None or edge_type is None or edge_length is None:\n", - " edge_index, edge_type = extend_graph_order_radius(\n", - " num_nodes=N,\n", - " pos=pos,\n", - " edge_index=bond_index,\n", - " edge_type=bond_type,\n", - " batch=batch,\n", - " order=self.edge_order,\n", - " cutoff=self.cutoff,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " is_sidechain=is_sidechain,\n", - " )\n", - " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", - " local_edge_mask = is_local_edge(edge_type) # (E, )\n", - "\n", - " # with the parameterization of NCSNv2\n", - " # DDPM loss implicit handle the noise variance scale conditioning\n", - " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", - "\n", - " # Encoding global\n", - " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - "\n", - " # Global\n", - " node_attr_global = self.encoder_global(\n", - " z=atom_type,\n", - " edge_index=edge_index,\n", - " edge_length=edge_length,\n", - " edge_attr=edge_attr_global,\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_global = assemble_atom_pair_feature(\n", - " node_attr=node_attr_global,\n", - " edge_index=edge_index,\n", - " edge_attr=edge_attr_global,\n", - " ) # (E_global, 2H)\n", - " # Invariant features of edges (radius graph, global)\n", - " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", - "\n", - " # Encoding local\n", - " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - " # edge_attr += temb_edge\n", - "\n", - " # Local\n", - " node_attr_local = self.encoder_local(\n", - " z=atom_type,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_local = assemble_atom_pair_feature(\n", - " node_attr=node_attr_local,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " ) # (E_local, 2H)\n", - "\n", - " # Invariant features of edges (bond graph, local)\n", - " if isinstance(sigma_edge, torch.Tensor):\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", - " 1.0 / sigma_edge[local_edge_mask]\n", - " ) # (E_local, 1)\n", - " else:\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", - "\n", - " if return_edges:\n", - " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", - " else:\n", - " return edge_inv_global, edge_inv_local\n", - "\n", - " def forward(\n", - " self,\n", - " sample,\n", - " timestep: Union[torch.Tensor, float, int],\n", - " return_dict: bool = True,\n", - " sigma=1.0,\n", - " global_start_sigma=0.5,\n", - " w_global=1.0,\n", - " extend_order=False,\n", - " extend_radius=True,\n", - " clip_local=None,\n", - " clip_global=1000.0,\n", - " ) -> Union[MoleculeGNNOutput, Tuple]:\n", - " r\"\"\"\n", - " Args:\n", - " sample: packed torch geometric object\n", - " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", - " return_dict (`bool`, *optional*, defaults to `True`):\n", - " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", - " Returns:\n", - " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", - " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n", - " \"\"\"\n", - "\n", - " # unpack sample\n", - " atom_type = sample.atom_type\n", - " bond_index = sample.edge_index\n", - " bond_type = sample.edge_type\n", - " num_graphs = sample.num_graphs\n", - " pos = sample.pos\n", - "\n", - " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", - "\n", - " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", - " atom_type=atom_type,\n", - " pos=sample.pos,\n", - " bond_index=bond_index,\n", - " bond_type=bond_type,\n", - " batch=sample.batch,\n", - " time_step=timesteps,\n", - " return_edges=True,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " ) # (E_global, 1), (E_local, 1)\n", - "\n", - " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", - " node_eq_local = graph_field_network(\n", - " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", - " )\n", - " if clip_local is not None:\n", - " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", - "\n", - " # Global\n", - " if sigma < global_start_sigma:\n", - " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", - " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", - " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", - " else:\n", - " node_eq_global = 0\n", - "\n", - " # Sum\n", - " eps_pos = node_eq_local + node_eq_global * w_global\n", - "\n", - " if not return_dict:\n", - " return (-eps_pos,)\n", - "\n", - " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" - ], - "metadata": { - "id": "MCeZA1qQXzoK" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CCIrPYSJj9wd" - }, - "source": [ - "### Load pretrained model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YdrAr6Ch--Ab" - }, - "source": [ - "#### Load a model\n", - "The model used is a design an\n", - "equivariant convolutional layer, named graph field network (GFN).\n", - "\n", - "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DyCo0nsqjbml", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 172, - "referenced_widgets": [ - "d90f304e9560472eacfbdd11e46765eb", - "1c6246f15b654f4daa11c9bcf997b78c", - "c2321b3bff6f490ca12040a20308f555", - "b7feb522161f4cf4b7cc7c1a078ff12d", - "e2d368556e494ae7ae4e2e992af2cd4f", - "bbef741e76ec41b7ab7187b487a383df", - "561f742d418d4721b0670cc8dd62e22c", - "872915dd1bb84f538c44e26badabafdd", - "d022575f1fa2446d891650897f187b4d", - "fdc393f3468c432aa0ada05e238a5436", - "2c9362906e4b40189f16d14aa9a348da", - "6010fc8daa7a44d5aec4b830ec2ebaa1", - "7e0bb1b8d65249d3974200686b193be2", - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "6526646be5ed415c84d1245b040e629b", - "24d31fc3576e43dd9f8301d2ef3a37ab", - "2918bfaadc8d4b1a9832522c40dfefb8", - "a4bfdca35cc54dae8812720f1b276a08", - "e4901541199b45c6a18824627692fc39", - "f915cf874246446595206221e900b2fe", - "a9e388f22a9742aaaf538e22575c9433", - "42f6c3db29d7484ba6b4f73590abd2f4" - ] - }, - "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", - "\n", - "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", - "\n" - ] - } - ], - "source": [ - "import torch\n", - "import numpy as np\n", - "\n", - "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", - "dataset = torch.load('/content/molecules.pkl')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QZcmy1EvKQRk" - }, - "source": [ - "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JVjz6iH_H6Eh", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" - ] - }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "dataset[0]" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Run the diffusion process" - ], - "metadata": { - "id": "vHNiZAUxNgoy" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jZ1KZrxKqENg" - }, - "source": [ - "#### Helper Functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "s240tYueqKKf" - }, - "outputs": [], - "source": [ - "from torch_geometric.data import Data, Batch\n", - "from torch_scatter import scatter_add, scatter_mean\n", - "from tqdm import tqdm\n", - "import copy\n", - "import os\n", - "\n", - "def repeat_data(data: Data, num_repeat) -> Batch:\n", - " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", - " return Batch.from_data_list(datas)\n", - "\n", - "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", - " datas = batch.to_data_list()\n", - " new_data = []\n", - " for i in range(num_repeat):\n", - " new_data += copy.deepcopy(datas)\n", - " return Batch.from_data_list(new_data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AMnQTk0eqT7Z" - }, - "source": [ - "#### Constants" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WYGkzqgzrHmF" - }, - "outputs": [], - "source": [ - "num_samples = 1 # solutions per molecule\n", - "num_molecules = 3\n", - "\n", - "DEVICE = 'cuda'\n", - "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", - "# constants for inference\n", - "w_global = 0.5 #0,.3 for qm9\n", - "global_start_sigma = 0.5\n", - "eta = 1.0\n", - "clip_local = None\n", - "clip_pos = None\n", - "\n", - "# constands for data handling\n", - "save_traj = False\n", - "save_data = False\n", - "output_dir = '/content/'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-xD5bJ3SqM7t" - }, - "source": [ - "#### Generate samples!\n", - "Note that the 3d representation of a molecule is referred to as the **conformation**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "x9xuLUNg26z1", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " after removing the cwd from sys.path.\n", - "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" - ] - } - ], - "source": [ - "results = []\n", - "\n", - "# define sigmas\n", - "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", - "sigmas = sigmas.to(DEVICE)\n", - "\n", - "for count, data in enumerate(tqdm(dataset)):\n", - " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", - "\n", - " data_input = data.clone()\n", - " data_input['pos_ref'] = None\n", - " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", - "\n", - " # initial configuration\n", - " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", - "\n", - " # for logging animation of denoising\n", - " pos_traj = []\n", - " with torch.no_grad():\n", - "\n", - " # scale initial sample\n", - " pos = pos_init * sigmas[-1]\n", - " for t in scheduler.timesteps:\n", - " batch.pos = pos\n", - "\n", - " # generate geometry with model, then filter it\n", - " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", - "\n", - " # Update\n", - " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", - "\n", - " pos = reconstructed_pos\n", - "\n", - " if torch.isnan(pos).any():\n", - " print(\"NaN detected. Please restart.\")\n", - " raise FloatingPointError()\n", - "\n", - " # recenter graph of positions for next iteration\n", - " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", - "\n", - " # optional clipping\n", - " if clip_pos is not None:\n", - " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", - " pos_traj.append(pos.clone().cpu())\n", - "\n", - " pos_gen = pos.cpu()\n", - " if save_traj:\n", - " pos_gen_traj = pos_traj.cpu()\n", - " data.pos_gen = torch.stack(pos_gen_traj)\n", - " else:\n", - " data.pos_gen = pos_gen\n", - " results.append(data)\n", - "\n", - "\n", - "if save_data:\n", - " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", - "\n", - " with open(save_path, 'wb') as f:\n", - " pickle.dump(results, f)" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Render the results!" - ], - "metadata": { - "id": "fSApwSaZNndW" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d47Zxo2OKdgZ" - }, - "source": [ - "This function allows us to render 3d in colab." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "e9Cd0kCAv9b8" - }, - "outputs": [], - "source": [ - "from google.colab import output\n", - "output.enable_custom_widget_manager()" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Helper functions" - ], - "metadata": { - "id": "RjaVuR15NqzF" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "28rBYa9NKhlz" - }, - "source": [ - "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LKdKdwxcyTQ6" - }, - "outputs": [], - "source": [ - "from copy import deepcopy\n", - "def set_rdmol_positions(rdkit_mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " mol = deepcopy(rdkit_mol)\n", - " set_rdmol_positions_(mol, pos)\n", - " return mol\n", - "\n", - "def set_rdmol_positions_(mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " for i in range(pos.shape[0]):\n", - " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", - " return mol\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NuE10hcpKmzK" - }, - "source": [ - "Process the generated data to make it easy to view." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KieVE1vc0_Vs", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "collect 5 generated molecules in `mols`\n" - ] - } - ], - "source": [ - "# the model can generate multiple conformations per 2d geometry\n", - "num_gen = results[0]['pos_gen'].shape[0]\n", - "\n", - "# init storage objects\n", - "mols_gen = []\n", - "mols_orig = []\n", - "for to_process in results:\n", - "\n", - " # store the reference 3d position\n", - " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # store the generated 3d position\n", - " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # copy data to new object\n", - " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", - "\n", - " # append results\n", - " mols_gen.append(new_mol)\n", - " mols_orig.append(to_process.rdmol)\n", - "\n", - "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tin89JwMKp4v" - }, - "source": [ - "Import tools to visualize the 2d chemical diagram of the molecule." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yqV6gllSZn38" - }, - "outputs": [], - "source": [ - "from rdkit.Chem import AllChem\n", - "from rdkit import Chem\n", - "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n", - "from IPython.display import SVG, display" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TFNKmGddVoOk" - }, - "source": [ - "Select molecule to visualize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KzuwLlrrVaGc" - }, - "outputs": [], - "source": [ - "idx = 0\n", - "assert idx < len(results), \"selected molecule that was not generated\"" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Viewing" - ], - "metadata": { - "id": "hkb8w0_SNtU8" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I3R4QBQeKttN" - }, - "source": [ - "This 2D rendering is the equivalent of the **input to the model**!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gkQRWjraaKex", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 321 - }, - "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "" - ], - "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" - }, - "metadata": {} - } - ], - "source": [ - "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", - "molSize=(450,300)\n", - "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", - "drawer.DrawMolecule(mc)\n", - "drawer.FinishDrawing()\n", - "svg = drawer.GetDrawingText()\n", - "display(SVG(svg.replace('svg:','')))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z4FDMYMxKw2I" - }, - "source": [ - "Generate the 3d molecule!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aT1Bkb8YxJfV", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17, - "referenced_widgets": [ - "695ab5bbf30a4ab19df1f9f33469f314", - "eac6a8dcdc9d4335a2e51031793ead29" - ] - }, - "outputId": "b98870ae-049d-4386-b676-166e9526bda2" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "695ab5bbf30a4ab19df1f9f33469f314" - } - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" - } - } - } - } - } - ], - "source": [ - "from nglview import show_rdkit as show" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pxtq8I-I18C-", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 337, - "referenced_widgets": [ - "be446195da2b4ff2aec21ec5ff963a54", - "c6596896148b4a8a9c57963b67c7782f", - "2489b5e5648541fbbdceadb05632a050", - "01e0ba4e5da04914b4652b8d58565d7b", - "c30e6c2f3e2a44dbbb3d63bd519acaa4", - "f31c6e40e9b2466a9064a2669933ecd5", - "19308ccac642498ab8b58462e3f1b0bb", - "4a081cdc2ec3421ca79dd933b7e2b0c4", - "e5c0d75eb5e1447abd560c8f2c6017e1", - "5146907ef6764654ad7d598baebc8b58", - "144ec959b7604a2cabb5ca46ae5e5379", - "abce2a80e6304df3899109c6d6cac199", - "65195cb7a4134f4887e9dd19f3676462" - ] - }, - "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "NGLWidget()" - ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "be446195da2b4ff2aec21ec5ff963a54" - } - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" - } - } - } - } - } - ], - "source": [ - "# new molecule\n", - "show(mols_gen[idx])" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "KJr4h2mwXeTo" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "d90f304e9560472eacfbdd11e46765eb": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", - "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", - "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" - ], - "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" - } - }, - "1c6246f15b654f4daa11c9bcf997b78c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", - "placeholder": "​", - "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", - "value": "Downloading: 100%" - } - }, - "c2321b3bff6f490ca12040a20308f555": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", - "max": 3271865, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", - "value": 3271865 - } - }, - "b7feb522161f4cf4b7cc7c1a078ff12d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", - "placeholder": "​", - "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", - "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" - } - }, - "e2d368556e494ae7ae4e2e992af2cd4f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "bbef741e76ec41b7ab7187b487a383df": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "561f742d418d4721b0670cc8dd62e22c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "872915dd1bb84f538c44e26badabafdd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d022575f1fa2446d891650897f187b4d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "fdc393f3468c432aa0ada05e238a5436": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2c9362906e4b40189f16d14aa9a348da": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6010fc8daa7a44d5aec4b830ec2ebaa1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", - "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "IPY_MODEL_6526646be5ed415c84d1245b040e629b" - ], - "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" - } - }, - "7e0bb1b8d65249d3974200686b193be2": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", - "placeholder": "​", - "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", - "value": "Downloading: 100%" - } - }, - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", - "max": 401, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", - "value": 401 - } - }, - "6526646be5ed415c84d1245b040e629b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", - "placeholder": "​", - "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", - "value": " 401/401 [00:00<00:00, 13.5kB/s]" - } - }, - "24d31fc3576e43dd9f8301d2ef3a37ab": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2918bfaadc8d4b1a9832522c40dfefb8": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a4bfdca35cc54dae8812720f1b276a08": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e4901541199b45c6a18824627692fc39": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f915cf874246446595206221e900b2fe": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "a9e388f22a9742aaaf538e22575c9433": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "42f6c3db29d7484ba6b4f73590abd2f4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "695ab5bbf30a4ab19df1f9f33469f314": { - "model_module": "nglview-js-widgets", - "model_name": "ColormakerRegistryModel", - "model_module_version": "3.0.1", - "state": { - "_dom_classes": [], - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "ColormakerRegistryModel", - "_msg_ar": [], - "_msg_q": [], - "_ready": false, - "_view_count": null, - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "ColormakerRegistryView", - "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" - } - }, - "eac6a8dcdc9d4335a2e51031793ead29": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "be446195da2b4ff2aec21ec5ff963a54": { - "model_module": "nglview-js-widgets", - "model_name": "NGLModel", - "model_module_version": "3.0.1", - "state": { - "_camera_orientation": [ - -15.519693580202304, - -14.065056548036177, - -23.53197484807691, - 0, - -23.357853515109753, - 20.94055073042662, - 2.888695042134944, - 0, - 14.352363398292777, - 18.870825741878015, - -20.744689572909344, - 0, - 0.2724999189376831, - 0.6940000057220459, - -0.3734999895095825, - 1 - ], - "_camera_str": "orthographic", - "_dom_classes": [], - "_gui_theme": null, - "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", - "_igui": null, - "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "NGLModel", - "_ngl_color_dict": {}, - "_ngl_coordinate_resource": {}, - "_ngl_full_stage_parameters": { - "impostor": true, - "quality": "medium", - "workerDefault": true, - "sampleLevel": 0, - "backgroundColor": "white", - "rotateSpeed": 2, - "zoomSpeed": 1.2, - "panSpeed": 1, - "clipNear": 0, - "clipFar": 100, - "clipDist": 10, - "fogNear": 50, - "fogFar": 100, - "cameraFov": 40, - "cameraEyeSep": 0.3, - "cameraType": "perspective", - "lightColor": 14540253, - "lightIntensity": 1, - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "hoverTimeout": 0, - "tooltip": true, - "mousePreset": "default" - }, - "_ngl_msg_archive": [ - { - "target": "Stage", - "type": "call_method", - "methodName": "loadFile", - "reconstruc_color_scheme": false, - "args": [ - { - "type": "blob", - "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", - "binary": false - } - ], - "kwargs": { - "defaultRepresentation": true, - "ext": "pdb" - } - } - ], - "_ngl_original_stage_parameters": { - "impostor": true, - "quality": "medium", - "workerDefault": true, - "sampleLevel": 0, - "backgroundColor": "white", - "rotateSpeed": 2, - "zoomSpeed": 1.2, - "panSpeed": 1, - "clipNear": 0, - "clipFar": 100, - "clipDist": 10, - "fogNear": 50, - "fogFar": 100, - "cameraFov": 40, - "cameraEyeSep": 0.3, - "cameraType": "perspective", - "lightColor": 14540253, - "lightIntensity": 1, - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "hoverTimeout": 0, - "tooltip": true, - "mousePreset": "default" - }, - "_ngl_repr_dict": { - "0": { - "0": { - "type": "ball+stick", - "params": { - "lazy": false, - "visible": true, - "quality": "high", - "sphereDetail": 2, - "radialSegments": 20, - "openEnded": true, - "disableImpostor": false, - "aspectRatio": 1.5, - "lineOnly": false, - "cylinderOnly": false, - "multipleBond": "off", - "bondScale": 0.3, - "bondSpacing": 0.75, - "linewidth": 2, - "radiusType": "size", - "radiusData": {}, - "radiusSize": 0.15, - "radiusScale": 2, - "assembly": "default", - "defaultAssembly": "", - "clipNear": 0, - "clipRadius": 0, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 - }, - "flatShaded": false, - "opacity": 1, - "depthWrite": true, - "side": "double", - "wireframe": false, - "colorScheme": "element", - "colorScale": "", - "colorReverse": false, - "colorValue": 9474192, - "colorMode": "hcl", - "roughness": 0.4, - "metalness": 0, - "diffuse": 16777215, - "diffuseInterior": false, - "useInteriorColor": true, - "interiorColor": 2236962, - "interiorDarkening": 0, - "matrix": { - "elements": [ - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1 - ] - }, - "disablePicking": false, - "sele": "" - } - } - }, - "1": { - "0": { - "type": "ball+stick", - "params": { - "lazy": false, - "visible": true, - "quality": "high", - "sphereDetail": 2, - "radialSegments": 20, - "openEnded": true, - "disableImpostor": false, - "aspectRatio": 1.5, - "lineOnly": false, - "cylinderOnly": false, - "multipleBond": "off", - "bondScale": 0.3, - "bondSpacing": 0.75, - "linewidth": 2, - "radiusType": "size", - "radiusData": {}, - "radiusSize": 0.15, - "radiusScale": 2, - "assembly": "default", - "defaultAssembly": "", - "clipNear": 0, - "clipRadius": 0, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 - }, - "flatShaded": false, - "opacity": 1, - "depthWrite": true, - "side": "double", - "wireframe": false, - "colorScheme": "element", - "colorScale": "", - "colorReverse": false, - "colorValue": 9474192, - "colorMode": "hcl", - "roughness": 0.4, - "metalness": 0, - "diffuse": 16777215, - "diffuseInterior": false, - "useInteriorColor": true, - "interiorColor": 2236962, - "interiorDarkening": 0, - "matrix": { - "elements": [ - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1 - ] - }, - "disablePicking": false, - "sele": "" - } - } - } - }, - "_ngl_serialize": false, - "_ngl_version": "", - "_ngl_view_id": [ - "FB989FD1-5B9C-446B-8914-6B58AF85446D" - ], - "_player_dict": {}, - "_scene_position": {}, - "_scene_rotation": {}, - "_synced_model_ids": [], - "_synced_repr_model_ids": [], - "_view_count": null, - "_view_height": "", - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "NGLView", - "_view_width": "", - "background": "white", - "frame": 0, - "gui_style": null, - "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", - "max_frame": 0, - "n_components": 2, - "picked": {} - } - }, - "c6596896148b4a8a9c57963b67c7782f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2489b5e5648541fbbdceadb05632a050": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ButtonModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ButtonView", - "button_style": "", - "description": "", - "disabled": false, - "icon": "compress", - "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", - "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", - "tooltip": "" - } - }, - "01e0ba4e5da04914b4652b8d58565d7b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", - "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" - ], - "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" - } - }, - "c30e6c2f3e2a44dbbb3d63bd519acaa4": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f31c6e40e9b2466a9064a2669933ecd5": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "19308ccac642498ab8b58462e3f1b0bb": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4a081cdc2ec3421ca79dd933b7e2b0c4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "SliderStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "", - "handle_color": null - } - }, - "e5c0d75eb5e1447abd560c8f2c6017e1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "PlayModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "PlayModel", - "_playing": false, - "_repeat": false, - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "PlayView", - "description": "", - "description_tooltip": null, - "disabled": false, - "interval": 100, - "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", - "max": 0, - "min": 0, - "show_repeat": true, - "step": 1, - "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", - "value": 0 - } - }, - "5146907ef6764654ad7d598baebc8b58": { - "model_module": "@jupyter-widgets/controls", - "model_name": "IntSliderModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "IntSliderModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "IntSliderView", - "continuous_update": true, - "description": "", - "description_tooltip": null, - "disabled": false, - "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", - "max": 0, - "min": 0, - "orientation": "horizontal", - "readout": true, - "readout_format": "d", - "step": 1, - "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", - "value": 0 - } - }, - "144ec959b7604a2cabb5ca46ae5e5379": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "abce2a80e6304df3899109c6d6cac199": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "34px" - } - }, - "65195cb7a4134f4887e9dd19f3676462": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ButtonStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "button_color": null, - "font_weight": "" - } - } - } - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "F88mignPnalS" + }, + "source": [ + "# Introduction\n", + "\n", + "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", + "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", + "\n", + "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n", + "\n", + "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", + "\n", + "> Colab made by [natolambert](https://twitter.com/natolambert).\n", + "\n", + "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" + ] }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + { + "cell_type": "markdown", + "metadata": { + "id": "7cnwXMocnuzB" + }, + "source": [ + "## Installations\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ff9SxWnaNId9" + }, + "source": [ + "### Install Conda" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1g_6zOabItDk" + }, + "source": [ + "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "K0ofXobG5Y-X", + "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nvcc: NVIDIA (R) Cuda compiler driver\n", + "Copyright (c) 2005-2021 NVIDIA Corporation\n", + "Built on Sun_Feb_14_21:12:58_PST_2021\n", + "Cuda compilation tools, release 11.2, V11.2.152\n", + "Build cuda_11.2.r11.2/compiler.29618528_0\n" + ] + } + ], + "source": [ + "!nvcc --version" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VfthW90vI0nw" + }, + "source": [ + "Install Conda for some more complex dependencies for geometric networks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2WNFzSnbiE0k", + "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q condacolab" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NUsbWYCUI7Km" + }, + "source": [ + "Setup Conda" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FZelreINdmd0", + "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✨🍰✨ Everything looks OK!\n" + ] + } + ], + "source": [ + "import condacolab\n", + "\n", + "\n", + "condacolab.install()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JzDHaPU7I9Sn" + }, + "source": [ + "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JMxRjHhL7w8V", + "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", + "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - cudatoolkit=11.1\n", + " - pytorch\n", + " - torchaudio\n", + " - torchvision\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 960 KB\n", + "\n", + "The following packages will be UPDATED:\n", + "\n", + " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", + "Preparing transaction: / \b\bdone\n", + "Verifying transaction: \\ \b\bdone\n", + "Executing transaction: / \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", + "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QDS6FPZ0Tu5b" + }, + "source": [ + "Need to remove a pathspec for colab that specifies the incorrect cuda version." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dq1lxR10TtrR", + "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" + ] + } + ], + "source": [ + "!rm /usr/local/conda-meta/pinned" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z1L3DdZOJB30" + }, + "source": [ + "Install torch geometric (used in the model later)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "D5ukfCOWfjzK", + "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - pytorch-geometric=1.7.2\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " decorator-4.4.2 | py_0 11 KB conda-forge\n", + " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", + " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", + " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", + " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", + " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", + " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", + " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", + " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", + " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", + " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", + " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", + " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", + " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", + " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", + " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", + " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", + " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", + " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", + " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 55.9 MB\n", + "\n", + "The following NEW packages will be INSTALLED:\n", + "\n", + " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", + " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", + " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", + " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", + " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", + " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", + " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", + " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", + " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", + " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", + " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", + " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", + " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", + " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", + " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", + " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", + " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", + " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", + " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", + "\n", + "The following packages will be DOWNGRADED:\n", + "\n", + " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", + "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", + "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", + "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", + "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", + "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", + "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", + "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", + "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", + "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", + "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", + "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", + "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", + "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", + "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", + "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", + "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", + "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", + "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", + "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", + "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", + "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install -c rusty1s pytorch-geometric=1.7.2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppxv6Mdkalbc" + }, + "source": [ + "### Install Diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mgQA_XN-XGY2", + "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/content\n", + "Cloning into 'diffusers'...\n", + "remote: Enumerating objects: 9298, done.\u001b[K\n", + "remote: Counting objects: 100% (40/40), done.\u001b[K\n", + "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", + "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", + "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", + "Resolving deltas: 100% (6168/6168), done.\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "%cd /content\n", + "\n", + "# install latest HF diffusers (will update to the release once added)\n", + "!git clone https://github.com/huggingface/diffusers.git\n", + "!pip install -q /content/diffusers\n", + "\n", + "# dependencies for diffusers\n", + "!pip install -q datasets transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LZO6AJKuJKO8" + }, + "source": [ + "Check that torch is installed correctly and utilizing the GPU in the colab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "id": "gZt7BNi1e1PA", + "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + }, + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'1.8.2'" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "\n", + "\n", + "print(torch.cuda.is_available())\n", + "torch.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KLE7CqlfJNUO" + }, + "source": [ + "### Install Chemistry-specific Dependencies\n", + "\n", + "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0CPv_NvehRz3", + "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting rdkit\n", + " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", + "Installing collected packages: rdkit\n", + "Successfully installed rdkit-2022.3.5\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install rdkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "88GaDbDPxJ5I" + }, + "source": [ + "### Get viewer from nglview\n", + "\n", + "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", + "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", + "The rdmol in this object is a source of ground truth for the generated molecules.\n", + "\n", + "You will use one rendering function from nglviewer later!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "jcl8GCS2mz6t", + "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting nglview\n", + " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", + "Collecting jupyterlab-widgets\n", + " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipywidgets>=7\n", + " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting widgetsnbextension~=4.0\n", + " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipython>=6.1.0\n", + " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipykernel>=4.5.1\n", + " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting traitlets>=4.3.1\n", + " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", + "Collecting pyzmq>=17\n", + " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting matplotlib-inline>=0.1\n", + " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", + "Collecting tornado>=6.1\n", + " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nest-asyncio\n", + " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", + "Collecting debugpy>=1.0\n", + " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting psutil\n", + " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jupyter-client>=6.1.12\n", + " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pickleshare\n", + " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", + "Collecting backcall\n", + " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pexpect>4.3\n", + " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pygments\n", + " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jedi>=0.16\n", + " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", + " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", + "Collecting parso<0.9.0,>=0.8.0\n", + " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", + "Collecting entrypoints\n", + " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", + "Collecting jupyter-core>=4.9.2\n", + " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ptyprocess>=0.5\n", + " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", + "Collecting wcwidth\n", + " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", + "Building wheels for collected packages: nglview\n", + " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", + " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", + "Successfully built nglview\n", + "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", + "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + }, + { + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "pexpect", + "pickleshare", + "wcwidth" + ] + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "!pip install nglview" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8t8_e_uVLdKB" + }, + "source": [ + "## Create a diffusion model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G0rMncVtNSqU" + }, + "source": [ + "### Model class(es)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L5FEXz5oXkzt" + }, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-3-P4w5sXkRU" + }, + "outputs": [], + "source": [ + "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", + "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", + "from dataclasses import dataclass\n", + "from typing import Callable, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor, nn\n", + "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", + "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", + "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", + "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", + "from torch_scatter import scatter_add\n", + "from torch_sparse import SparseTensor, coalesce\n", + "\n", + "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", + "from diffusers.modeling_utils import ModelMixin\n", + "from diffusers.utils import BaseOutput" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EzJQXPN_XrMX" + }, + "source": [ + "Helper classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oR1Y56QiLY90" + }, + "outputs": [], + "source": [ + "@dataclass\n", + "class MoleculeGNNOutput(BaseOutput):\n", + " \"\"\"\n", + " Args:\n", + " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", + " Hidden states output. Output of last layer of model.\n", + " \"\"\"\n", + "\n", + " sample: torch.Tensor\n", + "\n", + "\n", + "class MultiLayerPerceptron(nn.Module):\n", + " \"\"\"\n", + " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", + " Args:\n", + " input_dim (int): input dimension\n", + " hidden_dim (list of int): hidden dimensions\n", + " activation (str or function, optional): activation function\n", + " dropout (float, optional): dropout rate\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", + " super(MultiLayerPerceptron, self).__init__()\n", + "\n", + " self.dims = [input_dim] + hidden_dims\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", + " self.activation = None\n", + " if dropout > 0:\n", + " self.dropout = nn.Dropout(dropout)\n", + " else:\n", + " self.dropout = None\n", + "\n", + " self.layers = nn.ModuleList()\n", + " for i in range(len(self.dims) - 1):\n", + " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\"\"\"\n", + " for i, layer in enumerate(self.layers):\n", + " x = layer(x)\n", + " if i < len(self.layers) - 1:\n", + " if self.activation:\n", + " x = self.activation(x)\n", + " if self.dropout:\n", + " x = self.dropout(x)\n", + " return x\n", + "\n", + "\n", + "class ShiftedSoftplus(torch.nn.Module):\n", + " def __init__(self):\n", + " super(ShiftedSoftplus, self).__init__()\n", + " self.shift = torch.log(torch.tensor(2.0)).item()\n", + "\n", + " def forward(self, x):\n", + " return F.softplus(x) - self.shift\n", + "\n", + "\n", + "class CFConv(MessagePassing):\n", + " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", + " super(CFConv, self).__init__(aggr=\"add\")\n", + " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", + " self.lin2 = Linear(num_filters, out_channels)\n", + " self.nn = mlp\n", + " self.cutoff = cutoff\n", + " self.smooth = smooth\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", + " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", + " self.lin2.bias.data.fill_(0)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " if self.smooth:\n", + " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", + " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", + " else:\n", + " C = (edge_length <= self.cutoff).float()\n", + " W = self.nn(edge_attr) * C.view(-1, 1)\n", + "\n", + " x = self.lin1(x)\n", + " x = self.propagate(edge_index, x=x, W=W)\n", + " x = self.lin2(x)\n", + " return x\n", + "\n", + " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", + " return x_j * W\n", + "\n", + "\n", + "class InteractionBlock(torch.nn.Module):\n", + " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", + " super(InteractionBlock, self).__init__()\n", + " mlp = Sequential(\n", + " Linear(num_gaussians, num_filters),\n", + " ShiftedSoftplus(),\n", + " Linear(num_filters, num_filters),\n", + " )\n", + " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", + " self.act = ShiftedSoftplus()\n", + " self.lin = Linear(hidden_channels, hidden_channels)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " x = self.conv(x, edge_index, edge_length, edge_attr)\n", + " x = self.act(x)\n", + " x = self.lin(x)\n", + " return x\n", + "\n", + "\n", + "class SchNetEncoder(Module):\n", + " def __init__(\n", + " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.hidden_channels = hidden_channels\n", + " self.num_filters = num_filters\n", + " self.num_interactions = num_interactions\n", + " self.cutoff = cutoff\n", + "\n", + " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", + "\n", + " self.interactions = ModuleList()\n", + " for _ in range(num_interactions):\n", + " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", + " self.interactions.append(block)\n", + "\n", + " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", + " if embed_node:\n", + " assert z.dim() == 1 and z.dtype == torch.long\n", + " h = self.embedding(z)\n", + " else:\n", + " h = z\n", + " for interaction in self.interactions:\n", + " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", + "\n", + " return h\n", + "\n", + "\n", + "class GINEConv(MessagePassing):\n", + " \"\"\"\n", + " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", + " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", + " \"\"\"\n", + "\n", + " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", + " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", + " self.nn = mlp\n", + " self.initial_eps = eps\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " if train_eps:\n", + " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", + " else:\n", + " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", + "\n", + " def forward(\n", + " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", + " ) -> torch.Tensor:\n", + " \"\"\"\"\"\"\n", + " if isinstance(x, torch.Tensor):\n", + " x: OptPairTensor = (x, x)\n", + "\n", + " # Node and edge feature dimensionalites need to match.\n", + " if isinstance(edge_index, torch.Tensor):\n", + " assert edge_attr is not None\n", + " assert x[0].size(-1) == edge_attr.size(-1)\n", + " elif isinstance(edge_index, SparseTensor):\n", + " assert x[0].size(-1) == edge_index.size(-1)\n", + "\n", + " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", + " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", + "\n", + " x_r = x[1]\n", + " if x_r is not None:\n", + " out += (1 + self.eps) * x_r\n", + "\n", + " return self.nn(out)\n", + "\n", + " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", + " if self.activation:\n", + " return self.activation(x_j + edge_attr)\n", + " else:\n", + " return x_j + edge_attr\n", + "\n", + " def __repr__(self):\n", + " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", + "\n", + "\n", + "class GINEncoder(torch.nn.Module):\n", + " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", + " super().__init__()\n", + "\n", + " self.hidden_dim = hidden_dim\n", + " self.num_convs = num_convs\n", + " self.short_cut = short_cut\n", + " self.concat_hidden = concat_hidden\n", + " self.node_emb = nn.Embedding(100, hidden_dim)\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " self.convs = nn.ModuleList()\n", + " for i in range(self.num_convs):\n", + " self.convs.append(\n", + " GINEConv(\n", + " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", + " activation=activation,\n", + " )\n", + " )\n", + "\n", + " def forward(self, z, edge_index, edge_attr):\n", + " \"\"\"\n", + " Input:\n", + " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", + " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", + " Output:\n", + " node_feature: graph feature\n", + " \"\"\"\n", + "\n", + " node_attr = self.node_emb(z) # (num_node, hidden)\n", + "\n", + " hiddens = []\n", + " conv_input = node_attr # (num_node, hidden)\n", + "\n", + " for conv_idx, conv in enumerate(self.convs):\n", + " hidden = conv(conv_input, edge_index, edge_attr)\n", + " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", + " hidden = self.activation(hidden)\n", + " assert hidden.shape == conv_input.shape\n", + " if self.short_cut and hidden.shape == conv_input.shape:\n", + " hidden += conv_input\n", + "\n", + " hiddens.append(hidden)\n", + " conv_input = hidden\n", + "\n", + " if self.concat_hidden:\n", + " node_feature = torch.cat(hiddens, dim=-1)\n", + " else:\n", + " node_feature = hiddens[-1]\n", + "\n", + " return node_feature\n", + "\n", + "\n", + "class MLPEdgeEncoder(Module):\n", + " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", + " super().__init__()\n", + " self.hidden_dim = hidden_dim\n", + " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", + " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", + "\n", + " @property\n", + " def out_channels(self):\n", + " return self.hidden_dim\n", + "\n", + " def forward(self, edge_length, edge_type):\n", + " \"\"\"\n", + " Input:\n", + " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", + " Returns:\n", + " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", + " \"\"\"\n", + " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", + " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", + " return d_emb * edge_attr # (num_edge, hidden)\n", + "\n", + "\n", + "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", + " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", + " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", + " return h_pair\n", + "\n", + "\n", + "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", + " \"\"\"\n", + " Args:\n", + " num_nodes: Number of atoms.\n", + " edge_index: Bond indices of the original graph.\n", + " edge_type: Bond types of the original graph.\n", + " order: Extension order.\n", + " Returns:\n", + " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n", + " \"\"\"\n", + "\n", + " def binarize(x):\n", + " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", + "\n", + " def get_higher_order_adj_matrix(adj, order):\n", + " \"\"\"\n", + " Args:\n", + " adj: (N, N)\n", + " type_mat: (N, N)\n", + " Returns:\n", + " Following attributes will be updated:\n", + " - edge_index\n", + " - edge_type\n", + " Following attributes will be added to the data object:\n", + " - bond_edge_index: Original edge_index.\n", + " \"\"\"\n", + " adj_mats = [\n", + " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", + " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", + " ]\n", + "\n", + " for i in range(2, order + 1):\n", + " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", + " order_mat = torch.zeros_like(adj)\n", + "\n", + " for i in range(1, order + 1):\n", + " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", + "\n", + " return order_mat\n", + "\n", + " num_types = 22\n", + " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", + " # from rdkit.Chem.rdchem import BondType as BT\n", + " N = num_nodes\n", + " adj = to_dense_adj(edge_index).squeeze(0)\n", + " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", + "\n", + " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", + " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", + " assert (type_mat * type_highorder == 0).all()\n", + " type_new = type_mat + type_highorder\n", + "\n", + " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", + " _, edge_order = dense_to_sparse(adj_order)\n", + "\n", + " # data.bond_edge_index = data.edge_index # Save original edges\n", + " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", + " assert edge_type.dim() == 1\n", + " N = pos.size(0)\n", + "\n", + " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", + "\n", + " if is_sidechain is None:\n", + " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", + " else:\n", + " # fetch sidechain and its batch index\n", + " is_sidechain = is_sidechain.bool()\n", + " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", + " sidechain_pos = pos[is_sidechain]\n", + " sidechain_index = dummy_index[is_sidechain]\n", + " sidechain_batch = batch[is_sidechain]\n", + "\n", + " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", + " r_edge_index_x = assign_index[1]\n", + " r_edge_index_y = assign_index[0]\n", + " r_edge_index_y = sidechain_index[r_edge_index_y]\n", + "\n", + " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", + " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", + " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", + " # delete self loop\n", + " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", + "\n", + " rgraph_adj = torch.sparse.LongTensor(\n", + " rgraph_edge_index,\n", + " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", + " torch.Size([N, N]),\n", + " )\n", + "\n", + " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", + "\n", + " new_edge_index = composed_adj.indices()\n", + " new_edge_type = composed_adj.values().long()\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def extend_graph_order_radius(\n", + " num_nodes,\n", + " pos,\n", + " edge_index,\n", + " edge_type,\n", + " batch,\n", + " order=3,\n", + " cutoff=10.0,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + "):\n", + " if extend_order:\n", + " edge_index, edge_type = _extend_graph_order(\n", + " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", + " )\n", + "\n", + " if extend_radius:\n", + " edge_index, edge_type = _extend_to_radius_graph(\n", + " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", + " )\n", + "\n", + " return edge_index, edge_type\n", + "\n", + "\n", + "def get_distance(pos, edge_index):\n", + " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", + "\n", + "\n", + "def graph_field_network(score_d, pos, edge_index, edge_length):\n", + " \"\"\"\n", + " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", + " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", + " \"\"\"\n", + " N = pos.size(0)\n", + " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", + " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", + " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", + " ) # (N, 3)\n", + " return score_pos\n", + "\n", + "\n", + "def clip_norm(vec, limit, p=2):\n", + " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", + " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", + " return vec * denom\n", + "\n", + "\n", + "def is_local_edge(edge_type):\n", + " return edge_type > 0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QWrHJFcYXyUB" + }, + "source": [ + "Main model class!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MCeZA1qQXzoK" + }, + "outputs": [], + "source": [ + "class MoleculeGNN(ModelMixin, ConfigMixin):\n", + " @register_to_config\n", + " def __init__(\n", + " self,\n", + " hidden_dim=128,\n", + " num_convs=6,\n", + " num_convs_local=4,\n", + " cutoff=10.0,\n", + " mlp_act=\"relu\",\n", + " edge_order=3,\n", + " edge_encoder=\"mlp\",\n", + " smooth_conv=True,\n", + " ):\n", + " super().__init__()\n", + " self.cutoff = cutoff\n", + " self.edge_encoder = edge_encoder\n", + " self.edge_order = edge_order\n", + "\n", + " \"\"\"\n", + " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", + " in SchNetEncoder\n", + " \"\"\"\n", + " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + "\n", + " \"\"\"\n", + " The graph neural network that extracts node-wise features.\n", + " \"\"\"\n", + " self.encoder_global = SchNetEncoder(\n", + " hidden_channels=hidden_dim,\n", + " num_filters=hidden_dim,\n", + " num_interactions=num_convs,\n", + " edge_channels=self.edge_encoder_global.out_channels,\n", + " cutoff=cutoff,\n", + " smooth=smooth_conv,\n", + " )\n", + " self.encoder_local = GINEncoder(\n", + " hidden_dim=hidden_dim,\n", + " num_convs=num_convs_local,\n", + " )\n", + "\n", + " \"\"\"\n", + " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", + " gradients w.r.t. edge_length (out_dim = 1).\n", + " \"\"\"\n", + " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " \"\"\"\n", + " Incorporate parameters together\n", + " \"\"\"\n", + " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", + " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", + "\n", + " def _forward(\n", + " self,\n", + " atom_type,\n", + " pos,\n", + " bond_index,\n", + " bond_type,\n", + " batch,\n", + " time_step, # NOTE, model trained without timestep performed best\n", + " edge_index=None,\n", + " edge_type=None,\n", + " edge_length=None,\n", + " return_edges=False,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " atom_type: Types of atoms, (N, ).\n", + " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", + " bond_type: Bond types, (E, ).\n", + " batch: Node index to graph index, (N, ).\n", + " \"\"\"\n", + " N = atom_type.size(0)\n", + " if edge_index is None or edge_type is None or edge_length is None:\n", + " edge_index, edge_type = extend_graph_order_radius(\n", + " num_nodes=N,\n", + " pos=pos,\n", + " edge_index=bond_index,\n", + " edge_type=bond_type,\n", + " batch=batch,\n", + " order=self.edge_order,\n", + " cutoff=self.cutoff,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " is_sidechain=is_sidechain,\n", + " )\n", + " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", + " local_edge_mask = is_local_edge(edge_type) # (E, )\n", + "\n", + " # with the parameterization of NCSNv2\n", + " # DDPM loss implicit handle the noise variance scale conditioning\n", + " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", + "\n", + " # Encoding global\n", + " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + "\n", + " # Global\n", + " node_attr_global = self.encoder_global(\n", + " z=atom_type,\n", + " edge_index=edge_index,\n", + " edge_length=edge_length,\n", + " edge_attr=edge_attr_global,\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_global = assemble_atom_pair_feature(\n", + " node_attr=node_attr_global,\n", + " edge_index=edge_index,\n", + " edge_attr=edge_attr_global,\n", + " ) # (E_global, 2H)\n", + " # Invariant features of edges (radius graph, global)\n", + " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", + "\n", + " # Encoding local\n", + " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + " # edge_attr += temb_edge\n", + "\n", + " # Local\n", + " node_attr_local = self.encoder_local(\n", + " z=atom_type,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_local = assemble_atom_pair_feature(\n", + " node_attr=node_attr_local,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " ) # (E_local, 2H)\n", + "\n", + " # Invariant features of edges (bond graph, local)\n", + " if isinstance(sigma_edge, torch.Tensor):\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", + " 1.0 / sigma_edge[local_edge_mask]\n", + " ) # (E_local, 1)\n", + " else:\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", + "\n", + " if return_edges:\n", + " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", + " else:\n", + " return edge_inv_global, edge_inv_local\n", + "\n", + " def forward(\n", + " self,\n", + " sample,\n", + " timestep: Union[torch.Tensor, float, int],\n", + " return_dict: bool = True,\n", + " sigma=1.0,\n", + " global_start_sigma=0.5,\n", + " w_global=1.0,\n", + " extend_order=False,\n", + " extend_radius=True,\n", + " clip_local=None,\n", + " clip_global=1000.0,\n", + " ) -> Union[MoleculeGNNOutput, Tuple]:\n", + " r\"\"\"\n", + " Args:\n", + " sample: packed torch geometric object\n", + " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", + " return_dict (`bool`, *optional*, defaults to `True`):\n", + " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", + " Returns:\n", + " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", + " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n", + " \"\"\"\n", + "\n", + " # unpack sample\n", + " atom_type = sample.atom_type\n", + " bond_index = sample.edge_index\n", + " bond_type = sample.edge_type\n", + " num_graphs = sample.num_graphs\n", + " pos = sample.pos\n", + "\n", + " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", + "\n", + " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", + " atom_type=atom_type,\n", + " pos=sample.pos,\n", + " bond_index=bond_index,\n", + " bond_type=bond_type,\n", + " batch=sample.batch,\n", + " time_step=timesteps,\n", + " return_edges=True,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " ) # (E_global, 1), (E_local, 1)\n", + "\n", + " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", + " node_eq_local = graph_field_network(\n", + " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", + " )\n", + " if clip_local is not None:\n", + " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", + "\n", + " # Global\n", + " if sigma < global_start_sigma:\n", + " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", + " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", + " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", + " else:\n", + " node_eq_global = 0\n", + "\n", + " # Sum\n", + " eps_pos = node_eq_local + node_eq_global * w_global\n", + "\n", + " if not return_dict:\n", + " return (-eps_pos,)\n", + "\n", + " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CCIrPYSJj9wd" + }, + "source": [ + "### Load pretrained model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YdrAr6Ch--Ab" + }, + "source": [ + "#### Load a model\n", + "The model used is a design an\n", + "equivariant convolutional layer, named graph field network (GFN).\n", + "\n", + "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 172, + "referenced_widgets": [ + "d90f304e9560472eacfbdd11e46765eb", + "1c6246f15b654f4daa11c9bcf997b78c", + "c2321b3bff6f490ca12040a20308f555", + "b7feb522161f4cf4b7cc7c1a078ff12d", + "e2d368556e494ae7ae4e2e992af2cd4f", + "bbef741e76ec41b7ab7187b487a383df", + "561f742d418d4721b0670cc8dd62e22c", + "872915dd1bb84f538c44e26badabafdd", + "d022575f1fa2446d891650897f187b4d", + "fdc393f3468c432aa0ada05e238a5436", + "2c9362906e4b40189f16d14aa9a348da", + "6010fc8daa7a44d5aec4b830ec2ebaa1", + "7e0bb1b8d65249d3974200686b193be2", + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "6526646be5ed415c84d1245b040e629b", + "24d31fc3576e43dd9f8301d2ef3a37ab", + "2918bfaadc8d4b1a9832522c40dfefb8", + "a4bfdca35cc54dae8812720f1b276a08", + "e4901541199b45c6a18824627692fc39", + "f915cf874246446595206221e900b2fe", + "a9e388f22a9742aaaf538e22575c9433", + "42f6c3db29d7484ba6b4f73590abd2f4" + ] + }, + "id": "DyCo0nsqjbml", + "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d90f304e9560472eacfbdd11e46765eb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", + "\n", + "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", + "\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "\n", + "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", + "dataset = torch.load(\"/content/molecules.pkl\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QZcmy1EvKQRk" + }, + "source": [ + "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JVjz6iH_H6Eh", + "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vHNiZAUxNgoy" + }, + "source": [ + "## Run the diffusion process" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jZ1KZrxKqENg" + }, + "source": [ + "#### Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s240tYueqKKf" + }, + "outputs": [], + "source": [ + "import copy\n", + "import os\n", + "\n", + "from torch_geometric.data import Batch, Data\n", + "from torch_scatter import scatter_mean\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "def repeat_data(data: Data, num_repeat) -> Batch:\n", + " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", + " return Batch.from_data_list(datas)\n", + "\n", + "\n", + "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", + " datas = batch.to_data_list()\n", + " new_data = []\n", + " for i in range(num_repeat):\n", + " new_data += copy.deepcopy(datas)\n", + " return Batch.from_data_list(new_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AMnQTk0eqT7Z" + }, + "source": [ + "#### Constants" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WYGkzqgzrHmF" + }, + "outputs": [], + "source": [ + "num_samples = 1 # solutions per molecule\n", + "num_molecules = 3\n", + "\n", + "DEVICE = \"cuda\"\n", + "sampling_type = \"ddpm_noisy\" #'' # paper also uses \"generalize\" and \"ld\"\n", + "# constants for inference\n", + "w_global = 0.5 # 0,.3 for qm9\n", + "global_start_sigma = 0.5\n", + "eta = 1.0\n", + "clip_local = None\n", + "clip_pos = None\n", + "\n", + "# constands for data handling\n", + "save_traj = False\n", + "save_data = False\n", + "output_dir = \"/content/\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-xD5bJ3SqM7t" + }, + "source": [ + "#### Generate samples!\n", + "Note that the 3d representation of a molecule is referred to as the **conformation**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "x9xuLUNg26z1", + "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " after removing the cwd from sys.path.\n", + "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" + ] + } + ], + "source": [ + "import pickle\n", + "\n", + "\n", + "results = []\n", + "\n", + "# define sigmas\n", + "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", + "sigmas = sigmas.to(DEVICE)\n", + "\n", + "for count, data in enumerate(tqdm(dataset)):\n", + " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", + "\n", + " data_input = data.clone()\n", + " data_input[\"pos_ref\"] = None\n", + " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", + "\n", + " # initial configuration\n", + " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", + "\n", + " # for logging animation of denoising\n", + " pos_traj = []\n", + " with torch.no_grad():\n", + " # scale initial sample\n", + " pos = pos_init * sigmas[-1]\n", + " for t in scheduler.timesteps:\n", + " batch.pos = pos\n", + "\n", + " # generate geometry with model, then filter it\n", + " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", + "\n", + " # Update\n", + " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", + "\n", + " pos = reconstructed_pos\n", + "\n", + " if torch.isnan(pos).any():\n", + " print(\"NaN detected. Please restart.\")\n", + " raise FloatingPointError()\n", + "\n", + " # recenter graph of positions for next iteration\n", + " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", + "\n", + " # optional clipping\n", + " if clip_pos is not None:\n", + " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", + " pos_traj.append(pos.clone().cpu())\n", + "\n", + " pos_gen = pos.cpu()\n", + " if save_traj:\n", + " pos_gen_traj = pos_traj.cpu()\n", + " data.pos_gen = torch.stack(pos_gen_traj)\n", + " else:\n", + " data.pos_gen = pos_gen\n", + " results.append(data)\n", + "\n", + "\n", + "if save_data:\n", + " save_path = os.path.join(output_dir, \"samples_all.pkl\")\n", + "\n", + " with open(save_path, \"wb\") as f:\n", + " pickle.dump(results, f)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fSApwSaZNndW" + }, + "source": [ + "## Render the results!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d47Zxo2OKdgZ" + }, + "source": [ + "This function allows us to render 3d in colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e9Cd0kCAv9b8" + }, + "outputs": [], + "source": [ + "from google.colab import output\n", + "\n", + "\n", + "output.enable_custom_widget_manager()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RjaVuR15NqzF" + }, + "source": [ + "### Helper functions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "28rBYa9NKhlz" + }, + "source": [ + "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LKdKdwxcyTQ6" + }, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "\n", + "def set_rdmol_positions(rdkit_mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " mol = deepcopy(rdkit_mol)\n", + " set_rdmol_positions_(mol, pos)\n", + " return mol\n", + "\n", + "\n", + "def set_rdmol_positions_(mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " for i in range(pos.shape[0]):\n", + " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", + " return mol" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NuE10hcpKmzK" + }, + "source": [ + "Process the generated data to make it easy to view." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KieVE1vc0_Vs", + "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "collect 5 generated molecules in `mols`\n" + ] + } + ], + "source": [ + "# the model can generate multiple conformations per 2d geometry\n", + "num_gen = results[0][\"pos_gen\"].shape[0]\n", + "\n", + "# init storage objects\n", + "mols_gen = []\n", + "mols_orig = []\n", + "for to_process in results:\n", + " # store the reference 3d position\n", + " to_process[\"pos_ref\"] = to_process[\"pos_ref\"].reshape(-1, to_process[\"rdmol\"].GetNumAtoms(), 3)\n", + "\n", + " # store the generated 3d position\n", + " to_process[\"pos_gen\"] = to_process[\"pos_gen\"].reshape(-1, to_process[\"rdmol\"].GetNumAtoms(), 3)\n", + "\n", + " # copy data to new object\n", + " new_mol = set_rdmol_positions(to_process.rdmol, to_process[\"pos_gen\"][0])\n", + "\n", + " # append results\n", + " mols_gen.append(new_mol)\n", + " mols_orig.append(to_process.rdmol)\n", + "\n", + "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tin89JwMKp4v" + }, + "source": [ + "Import tools to visualize the 2d chemical diagram of the molecule." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yqV6gllSZn38" + }, + "outputs": [], + "source": [ + "from IPython.display import SVG, display\n", + "from rdkit import Chem\n", + "from rdkit.Chem.Draw import rdMolDraw2D as MD2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TFNKmGddVoOk" + }, + "source": [ + "Select molecule to visualize" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KzuwLlrrVaGc" + }, + "outputs": [], + "source": [ + "idx = 0\n", + "assert idx < len(results), \"selected molecule that was not generated\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hkb8w0_SNtU8" + }, + "source": [ + "### Viewing" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I3R4QBQeKttN" + }, + "source": [ + "This 2D rendering is the equivalent of the **input to the model**!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 321 + }, + "id": "gkQRWjraaKex", + "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mc = Chem.MolFromSmiles(dataset[0][\"smiles\"])\n", + "molSize = (450, 300)\n", + "drawer = MD2.MolDraw2DSVG(molSize[0], molSize[1])\n", + "drawer.DrawMolecule(mc)\n", + "drawer.FinishDrawing()\n", + "svg = drawer.GetDrawingText()\n", + "display(SVG(svg.replace(\"svg:\", \"\")))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z4FDMYMxKw2I" + }, + "source": [ + "Generate the 3d molecule!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "695ab5bbf30a4ab19df1f9f33469f314", + "eac6a8dcdc9d4335a2e51031793ead29" + ] + }, + "id": "aT1Bkb8YxJfV", + "outputId": "b98870ae-049d-4386-b676-166e9526bda2" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "695ab5bbf30a4ab19df1f9f33469f314", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [] + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + }, + "output_type": "display_data" + } + ], + "source": [ + "from nglview import show_rdkit as show" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "be446195da2b4ff2aec21ec5ff963a54", + "c6596896148b4a8a9c57963b67c7782f", + "2489b5e5648541fbbdceadb05632a050", + "01e0ba4e5da04914b4652b8d58565d7b", + "c30e6c2f3e2a44dbbb3d63bd519acaa4", + "f31c6e40e9b2466a9064a2669933ecd5", + "19308ccac642498ab8b58462e3f1b0bb", + "4a081cdc2ec3421ca79dd933b7e2b0c4", + "e5c0d75eb5e1447abd560c8f2c6017e1", + "5146907ef6764654ad7d598baebc8b58", + "144ec959b7604a2cabb5ca46ae5e5379", + "abce2a80e6304df3899109c6d6cac199", + "65195cb7a4134f4887e9dd19f3676462" + ] + }, + "id": "pxtq8I-I18C-", + "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "be446195da2b4ff2aec21ec5ff963a54", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "NGLWidget()" + ] + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + }, + "output_type": "display_data" + } + ], + "source": [ + "# new molecule\n", + "show(mols_gen[idx])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KJr4h2mwXeTo" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "01e0ba4e5da04914b4652b8d58565d7b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", + "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" + ], + "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" + } + }, + "144ec959b7604a2cabb5ca46ae5e5379": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "19308ccac642498ab8b58462e3f1b0bb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1c6246f15b654f4daa11c9bcf997b78c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", + "placeholder": "​", + "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", + "value": "Downloading: 100%" + } + }, + "2489b5e5648541fbbdceadb05632a050": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "", + "description": "", + "disabled": false, + "icon": "compress", + "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", + "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", + "tooltip": "" + } + }, + "24d31fc3576e43dd9f8301d2ef3a37ab": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2918bfaadc8d4b1a9832522c40dfefb8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c9362906e4b40189f16d14aa9a348da": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "42f6c3db29d7484ba6b4f73590abd2f4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "4a081cdc2ec3421ca79dd933b7e2b0c4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "SliderStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "SliderStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "", + "handle_color": null + } + }, + "5146907ef6764654ad7d598baebc8b58": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "IntSliderModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "IntSliderModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "IntSliderView", + "continuous_update": true, + "description": "", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", + "max": 0, + "min": 0, + "orientation": "horizontal", + "readout": true, + "readout_format": "d", + "step": 1, + "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", + "value": 0 + } + }, + "561f742d418d4721b0670cc8dd62e22c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6010fc8daa7a44d5aec4b830ec2ebaa1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", + "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "IPY_MODEL_6526646be5ed415c84d1245b040e629b" + ], + "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" + } + }, + "65195cb7a4134f4887e9dd19f3676462": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + }, + "6526646be5ed415c84d1245b040e629b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", + "placeholder": "​", + "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", + "value": " 401/401 [00:00<00:00, 13.5kB/s]" + } + }, + "695ab5bbf30a4ab19df1f9f33469f314": { + "model_module": "nglview-js-widgets", + "model_module_version": "3.0.1", + "model_name": "ColormakerRegistryModel", + "state": { + "_dom_classes": [], + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "ColormakerRegistryModel", + "_msg_ar": [], + "_msg_q": [], + "_ready": false, + "_view_count": null, + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "ColormakerRegistryView", + "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" + } + }, + "7e0bb1b8d65249d3974200686b193be2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", + "placeholder": "​", + "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", + "value": "Downloading: 100%" + } + }, + "872915dd1bb84f538c44e26badabafdd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a4bfdca35cc54dae8812720f1b276a08": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a9e388f22a9742aaaf538e22575c9433": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "abce2a80e6304df3899109c6d6cac199": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "34px" + } + }, + "b7feb522161f4cf4b7cc7c1a078ff12d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", + "placeholder": "​", + "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", + "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" + } + }, + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", + "max": 401, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", + "value": 401 + } + }, + "bbef741e76ec41b7ab7187b487a383df": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be446195da2b4ff2aec21ec5ff963a54": { + "model_module": "nglview-js-widgets", + "model_module_version": "3.0.1", + "model_name": "NGLModel", + "state": { + "_camera_orientation": [ + -15.519693580202304, + -14.065056548036177, + -23.53197484807691, + 0, + -23.357853515109753, + 20.94055073042662, + 2.888695042134944, + 0, + 14.352363398292775, + 18.870825741878015, + -20.744689572909344, + 0, + 0.2724999189376831, + 0.6940000057220459, + -0.3734999895095825, + 1 + ], + "_camera_str": "orthographic", + "_dom_classes": [], + "_gui_theme": null, + "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", + "_igui": null, + "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "NGLModel", + "_ngl_color_dict": {}, + "_ngl_coordinate_resource": {}, + "_ngl_full_stage_parameters": { + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "backgroundColor": "white", + "cameraEyeSep": 0.3, + "cameraFov": 40, + "cameraType": "perspective", + "clipDist": 10, + "clipFar": 100, + "clipNear": 0, + "fogFar": 100, + "fogNear": 50, + "hoverTimeout": 0, + "impostor": true, + "lightColor": 14540253, + "lightIntensity": 1, + "mousePreset": "default", + "panSpeed": 1, + "quality": "medium", + "rotateSpeed": 2, + "sampleLevel": 0, + "tooltip": true, + "workerDefault": true, + "zoomSpeed": 1.2 + }, + "_ngl_msg_archive": [ + { + "args": [ + { + "binary": false, + "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", + "type": "blob" + } + ], + "kwargs": { + "defaultRepresentation": true, + "ext": "pdb" + }, + "methodName": "loadFile", + "reconstruc_color_scheme": false, + "target": "Stage", + "type": "call_method" + } + ], + "_ngl_original_stage_parameters": { + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "backgroundColor": "white", + "cameraEyeSep": 0.3, + "cameraFov": 40, + "cameraType": "perspective", + "clipDist": 10, + "clipFar": 100, + "clipNear": 0, + "fogFar": 100, + "fogNear": 50, + "hoverTimeout": 0, + "impostor": true, + "lightColor": 14540253, + "lightIntensity": 1, + "mousePreset": "default", + "panSpeed": 1, + "quality": "medium", + "rotateSpeed": 2, + "sampleLevel": 0, + "tooltip": true, + "workerDefault": true, + "zoomSpeed": 1.2 + }, + "_ngl_repr_dict": { + "0": { + "0": { + "params": { + "aspectRatio": 1.5, + "assembly": "default", + "bondScale": 0.3, + "bondSpacing": 0.75, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "clipNear": 0, + "clipRadius": 0, + "colorMode": "hcl", + "colorReverse": false, + "colorScale": "", + "colorScheme": "element", + "colorValue": 9474192, + "cylinderOnly": false, + "defaultAssembly": "", + "depthWrite": true, + "diffuse": 16777215, + "diffuseInterior": false, + "disableImpostor": false, + "disablePicking": false, + "flatShaded": false, + "interiorColor": 2236962, + "interiorDarkening": 0, + "lazy": false, + "lineOnly": false, + "linewidth": 2, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "metalness": 0, + "multipleBond": "off", + "opacity": 1, + "openEnded": true, + "quality": "high", + "radialSegments": 20, + "radiusData": {}, + "radiusScale": 2, + "radiusSize": 0.15, + "radiusType": "size", + "roughness": 0.4, + "sele": "", + "side": "double", + "sphereDetail": 2, + "useInteriorColor": true, + "visible": true, + "wireframe": false + }, + "type": "ball+stick" + } + }, + "1": { + "0": { + "params": { + "aspectRatio": 1.5, + "assembly": "default", + "bondScale": 0.3, + "bondSpacing": 0.75, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "clipNear": 0, + "clipRadius": 0, + "colorMode": "hcl", + "colorReverse": false, + "colorScale": "", + "colorScheme": "element", + "colorValue": 9474192, + "cylinderOnly": false, + "defaultAssembly": "", + "depthWrite": true, + "diffuse": 16777215, + "diffuseInterior": false, + "disableImpostor": false, + "disablePicking": false, + "flatShaded": false, + "interiorColor": 2236962, + "interiorDarkening": 0, + "lazy": false, + "lineOnly": false, + "linewidth": 2, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "metalness": 0, + "multipleBond": "off", + "opacity": 1, + "openEnded": true, + "quality": "high", + "radialSegments": 20, + "radiusData": {}, + "radiusScale": 2, + "radiusSize": 0.15, + "radiusType": "size", + "roughness": 0.4, + "sele": "", + "side": "double", + "sphereDetail": 2, + "useInteriorColor": true, + "visible": true, + "wireframe": false + }, + "type": "ball+stick" + } + } + }, + "_ngl_serialize": false, + "_ngl_version": "", + "_ngl_view_id": [ + "FB989FD1-5B9C-446B-8914-6B58AF85446D" + ], + "_player_dict": {}, + "_scene_position": {}, + "_scene_rotation": {}, + "_synced_model_ids": [], + "_synced_repr_model_ids": [], + "_view_count": null, + "_view_height": "", + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "NGLView", + "_view_width": "", + "background": "white", + "frame": 0, + "gui_style": null, + "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", + "max_frame": 0, + "n_components": 2, + "picked": {} + } + }, + "c2321b3bff6f490ca12040a20308f555": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", + "max": 3271865, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", + "value": 3271865 + } + }, + "c30e6c2f3e2a44dbbb3d63bd519acaa4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c6596896148b4a8a9c57963b67c7782f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d022575f1fa2446d891650897f187b4d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "d90f304e9560472eacfbdd11e46765eb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", + "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", + "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" + ], + "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" + } + }, + "e2d368556e494ae7ae4e2e992af2cd4f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e4901541199b45c6a18824627692fc39": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e5c0d75eb5e1447abd560c8f2c6017e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "PlayModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "PlayModel", + "_playing": false, + "_repeat": false, + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "PlayView", + "description": "", + "description_tooltip": null, + "disabled": false, + "interval": 100, + "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", + "max": 0, + "min": 0, + "show_repeat": true, + "step": 1, + "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", + "value": 0 + } + }, + "eac6a8dcdc9d4335a2e51031793ead29": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f31c6e40e9b2466a9064a2669933ecd5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f915cf874246446595206221e900b2fe": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "fdc393f3468c432aa0ada05e238a5436": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb index 571f1a0323..315aee7105 100644 --- a/examples/research_projects/gligen/demo.ipynb +++ b/examples/research_projects/gligen/demo.ipynb @@ -26,8 +26,7 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "import torch\n", - "from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline" + "from diffusers import StableDiffusionGLIGENPipeline" ] }, { @@ -36,28 +35,25 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", + "from transformers import CLIPTextModel, CLIPTokenizer\n", + "\n", "import diffusers\n", "from diffusers import (\n", " AutoencoderKL,\n", " DDPMScheduler,\n", - " UNet2DConditionModel,\n", - " UniPCMultistepScheduler,\n", " EulerDiscreteScheduler,\n", + " UNet2DConditionModel,\n", ")\n", - "from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n", + "\n", + "\n", "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n", "\n", - "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n", + "pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n", "\n", "tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n", "noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n", - "text_encoder = CLIPTextModel.from_pretrained(\n", - " pretrained_model_name_or_path, subfolder=\"text_encoder\"\n", - ")\n", - "vae = AutoencoderKL.from_pretrained(\n", - " pretrained_model_name_or_path, subfolder=\"vae\"\n", - ")\n", + "text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n", + "vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n", "# unet = UNet2DConditionModel.from_pretrained(\n", "# pretrained_model_name_or_path, subfolder=\"unet\"\n", "# )\n", @@ -71,9 +67,7 @@ "metadata": {}, "outputs": [], "source": [ - "unet = UNet2DConditionModel.from_pretrained(\n", - " '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n", - ")" + "unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")" ] }, { @@ -108,6 +102,9 @@ "metadata": {}, "outputs": [], "source": [ + "import numpy as np\n", + "\n", + "\n", "# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n", "# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n", "\n", @@ -117,10 +114,8 @@ "# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n", "# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n", "\n", - "prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n", - "gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n", - "\n", - "import numpy as np\n", + "prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n", + "gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n", "\n", "boxes = np.array([x[1] for x in gen_boxes])\n", "boxes = boxes / 512\n", @@ -166,7 +161,7 @@ "metadata": {}, "outputs": [], "source": [ - "diffusers.utils.make_image_grid(images, 4, len(images)//4)" + "diffusers.utils.make_image_grid(images, 4, len(images) // 4)" ] }, { @@ -179,7 +174,7 @@ ], "metadata": { "kernelspec": { - "display_name": "densecaption", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -197,5 +192,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py index 1d9203be7e..f94b1dd6b5 100644 --- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py +++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py @@ -15,8 +15,8 @@ # limitations under the License. """ - Script to fine-tune Stable Diffusion for LORA InstructPix2Pix. - Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py +Script to fine-tune Stable Diffusion for LORA InstructPix2Pix. +Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py """ import argparse diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 0f507b26d6..57c555e43f 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -763,9 +763,9 @@ def main(args): # Parse instance and class inputs, and double check that lengths match instance_data_dir = args.instance_data_dir.split(",") instance_prompt = args.instance_prompt.split(",") - assert all( - x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)] - ), "Instance data dir and prompt inputs are not of the same length." + assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), ( + "Instance data dir and prompt inputs are not of the same length." + ) if args.with_prior_preservation: class_data_dir = args.class_data_dir.split(",") @@ -788,9 +788,9 @@ def main(args): negative_validation_prompts.append(None) args.validation_negative_prompt = negative_validation_prompts - assert num_of_validation_prompts == len( - negative_validation_prompts - ), "The length of negative prompts for validation is greater than the number of validation prompts." + assert num_of_validation_prompts == len(negative_validation_prompts), ( + "The length of negative prompts for validation is greater than the number of validation prompts." + ) args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py index 19432142f5..75dcfccbd5 100644 --- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py +++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py @@ -830,9 +830,9 @@ def main(): # Let's make sure we don't update any embedding weights besides the newly added token index_no_updates = get_mask(tokenizer, accelerator) with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ - index_no_updates - ] = orig_embeds_params[index_no_updates] + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( + orig_embeds_params[index_no_updates] + ) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py index 7f5dc8ece9..a881b06a94 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py +++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py @@ -886,9 +886,9 @@ def main(): index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ - index_no_updates - ] = orig_embeds_params[index_no_updates] + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( + orig_embeds_params[index_no_updates] + ) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py index 19c1f30d82..51668a61cd 100644 --- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py +++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py @@ -663,8 +663,7 @@ class PromptDiffusionPipeline( self.check_image(image, prompt, prompt_embeds) else: raise ValueError( - f"You have passed a list of images of length {len(image_pair)}." - f"Make sure the list size equals to two." + f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two." ) # Check `controlnet_conditioning_scale` diff --git a/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py index 9719585d3d..6ae1a9a6c6 100644 --- a/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py +++ b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py @@ -173,7 +173,7 @@ class TrainSD: if not dataloader_exception: xm.wait_device_ops() total_time = time.time() - last_time - print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}") + print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}") else: print("dataloader exception happen, skip result") return @@ -622,7 +622,7 @@ def main(args): num_devices_per_host = num_devices // num_hosts if xm.is_master_ordinal(): print("***** Running training *****") - print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }") + print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}") print( f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}" ) diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py index 26caba5a42..043f913893 100644 --- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py +++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py @@ -1057,7 +1057,7 @@ def main(args): if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: raise ValueError( - f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" + f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py index 410cd74a5b..393f991387 100644 --- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py +++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py @@ -1021,7 +1021,7 @@ def main(args): lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py index c02a59a007..01ef67a55d 100644 --- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py @@ -118,7 +118,7 @@ def save_model_card( ) model_description = f""" -# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} +# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id} @@ -1336,7 +1336,7 @@ def main(args): lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py index abc4399126..c87f50e272 100644 --- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py @@ -750,7 +750,7 @@ def main(args): raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py index f5bee58d45..ebb9b129db 100644 --- a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py +++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py @@ -765,7 +765,7 @@ def main(args): lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 2061f0c677..539d4a6575 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -767,7 +767,7 @@ def main(args): raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 757a12045f..51e220828c 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -910,9 +910,9 @@ def main(): index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ - index_no_updates - ] = orig_embeds_params[index_no_updates] + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( + orig_embeds_params[index_no_updates] + ) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index 11463943c4..f32c729195 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -965,12 +965,12 @@ def main(): index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False with torch.no_grad(): - accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[ - index_no_updates - ] = orig_embeds_params[index_no_updates] - accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[ - index_no_updates_2 - ] = orig_embeds_params_2[index_no_updates_2] + accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = ( + orig_embeds_params[index_no_updates] + ) + accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = ( + orig_embeds_params_2[index_no_updates_2] + ) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py index aa5d4c67b6..d13e102e78 100644 --- a/examples/vqgan/test_vqgan.py +++ b/examples/vqgan/test_vqgan.py @@ -177,7 +177,7 @@ class TextToImage(ExamplesTestsAccelerate): --model_config_name_or_path {vqmodel_config_path} --discriminator_config_name_or_path {discriminator_config_path} --checkpointing_steps=1 - --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} + --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")} --output_dir {tmpdir} --seed=0 """.split() @@ -262,7 +262,7 @@ class TextToImage(ExamplesTestsAccelerate): --model_config_name_or_path {vqmodel_config_path} --discriminator_config_name_or_path {discriminator_config_path} --checkpointing_steps=1 - --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} + --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")} --output_dir {tmpdir} --use_ema --seed=0 @@ -377,7 +377,7 @@ class TextToImage(ExamplesTestsAccelerate): --discriminator_config_name_or_path {discriminator_config_path} --output_dir {tmpdir} --checkpointing_steps=2 - --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} + --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")} --checkpoints_total_limit=2 --seed=0 """.split() diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 992722fa7a..33d234da52 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -653,15 +653,15 @@ def main(): try: # Gets the resolution of the timm transformation after centercrop timm_centercrop_transform = timm_transform.transforms[1] - assert isinstance( - timm_centercrop_transform, transforms.CenterCrop - ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." + assert isinstance(timm_centercrop_transform, transforms.CenterCrop), ( + f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." + ) timm_model_resolution = timm_centercrop_transform.size[0] # Gets final normalization timm_model_normalization = timm_transform.transforms[-1] - assert isinstance( - timm_model_normalization, transforms.Normalize - ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." + assert isinstance(timm_model_normalization, transforms.Normalize), ( + f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." + ) except AssertionError as e: raise NotImplementedError(e) # Enable flash attention if asked diff --git a/pyproject.toml b/pyproject.toml index 299865a122..a864ea34b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ line-length = 119 [tool.ruff.lint] # Never enforce `E501` (line length violations). -ignore = ["C901", "E501", "E741", "F402", "F823"] +ignore = ["C901", "E501", "E721", "E741", "F402", "F823"] select = ["C", "E", "F", "I", "W"] # Ignore import violations in all `__init__.py` files. diff --git a/scripts/convert_amused.py b/scripts/convert_amused.py index 21be29dfdb..ddd1bf508b 100644 --- a/scripts/convert_amused.py +++ b/scripts/convert_amused.py @@ -468,7 +468,7 @@ def make_vqvae(old_vae): # assert (old_output == new_output).all() print("skipping full vae equivalence check") - print(f"vae full diff { (old_output - new_output).float().abs().sum()}") + print(f"vae full diff {(old_output - new_output).float().abs().sum()}") return new_vae diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index 0f8b4ddca8..2b918280ca 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): if i != len(up_block_types) - 1: new_prefix = f"up_blocks.{i}.upsamplers.0" - old_prefix = f"output_blocks.{current_layer-1}.1" + old_prefix = f"output_blocks.{current_layer - 1}.1" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) elif layer_type == "AttnUpBlock2D": for j in range(layers_per_block + 1): @@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): if i != len(up_block_types) - 1: new_prefix = f"up_blocks.{i}.upsamplers.0" - old_prefix = f"output_blocks.{current_layer-1}.2" + old_prefix = f"output_blocks.{current_layer - 1}.2" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] diff --git a/scripts/convert_dance_diffusion_to_diffusers.py b/scripts/convert_dance_diffusion_to_diffusers.py index f9caa50dfc..e269a49070 100755 --- a/scripts/convert_dance_diffusion_to_diffusers.py +++ b/scripts/convert_dance_diffusion_to_diffusers.py @@ -261,9 +261,9 @@ def main(args): model_name = args.model_path.split("/")[-1].split(".")[0] if not os.path.isfile(args.model_path): - assert ( - model_name == args.model_path - ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" + assert model_name == args.model_path, ( + f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" + ) args.model_path = download(model_name) sample_rate = MODELS_MAP[model_name]["sample_rate"] @@ -290,9 +290,9 @@ def main(args): assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}" for key, value in renamed_state_dict.items(): - assert ( - diffusers_state_dict[key].squeeze().shape == value.squeeze().shape - ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" + assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, ( + f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" + ) if key == "time_proj.weight": value = value.squeeze() diff --git a/scripts/convert_diffusers_to_original_sdxl.py b/scripts/convert_diffusers_to_original_sdxl.py index 648d0376f7..1aa792b3f0 100644 --- a/scripts/convert_diffusers_to_original_sdxl.py +++ b/scripts/convert_diffusers_to_original_sdxl.py @@ -52,18 +52,18 @@ for i in range(3): for j in range(2): # loop over resnets/attentions for downblocks hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0." unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) if i > 0: hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1." unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) for j in range(4): # loop over resnets/attentions for upblocks hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + sd_up_res_prefix = f"output_blocks.{3 * i + j}.0." unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) if i < 2: @@ -75,12 +75,12 @@ for i in range(3): if i < 3: # no downsample in down_blocks.3 hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op." unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) # no upsample in up_blocks.3 hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}." unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv.")) @@ -89,7 +89,7 @@ sd_mid_atn_prefix = "middle_block.1." unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) for j in range(2): hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." + sd_mid_res_prefix = f"middle_block.{2 * j}." unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) @@ -137,20 +137,20 @@ for i in range(4): vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." + sd_upsample_prefix = f"up.{3 - i}.upsample." vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) # up_blocks have three resnets # also, up blocks in hf are numbered in reverse from sd for j in range(3): hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + sd_up_prefix = f"decoder.up.{3 - i}.block.{j}." vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) # this part accounts for mid blocks in both the encoder and the decoder for i in range(2): hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." + sd_mid_res_prefix = f"mid.block_{i + 1}." vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index d1b7df070c..049dda7d42 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -47,36 +47,36 @@ for i in range(4): for j in range(2): # loop over resnets/attentions for downblocks hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0." unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) if i < 3: # no attention layers in down_blocks.3 hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1." unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) for j in range(3): # loop over resnets/attentions for upblocks hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + sd_up_res_prefix = f"output_blocks.{3 * i + j}.0." unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) if i > 0: # no attention layers in up_blocks.0 hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1." unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) if i < 3: # no downsample in down_blocks.3 hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op." unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) # no upsample in up_blocks.3 hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}." unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) hf_mid_atn_prefix = "mid_block.attentions.0." @@ -85,7 +85,7 @@ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) for j in range(2): hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." + sd_mid_res_prefix = f"middle_block.{2 * j}." unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) @@ -133,20 +133,20 @@ for i in range(4): vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." + sd_upsample_prefix = f"up.{3 - i}.upsample." vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) # up_blocks have three resnets # also, up blocks in hf are numbered in reverse from sd for j in range(3): hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + sd_up_prefix = f"decoder.up.{3 - i}.block.{j}." vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) # this part accounts for mid blocks in both the encoder and the decoder for i in range(2): hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." + sd_mid_res_prefix = f"mid.block_{i + 1}." vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) diff --git a/scripts/convert_hunyuandit_controlnet_to_diffusers.py b/scripts/convert_hunyuandit_controlnet_to_diffusers.py index 1c83836908..5cef46c989 100644 --- a/scripts/convert_hunyuandit_controlnet_to_diffusers.py +++ b/scripts/convert_hunyuandit_controlnet_to_diffusers.py @@ -21,9 +21,9 @@ def main(args): model_config = HunyuanDiT2DControlNetModel.load_config( "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer" ) - model_config[ - "use_style_cond_and_image_meta_size" - ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False + model_config["use_style_cond_and_image_meta_size"] = ( + args.use_style_cond_and_image_meta_size + ) ### version <= v1.1: True; version >= v1.2: False print(model_config) for key in state_dict: diff --git a/scripts/convert_hunyuandit_to_diffusers.py b/scripts/convert_hunyuandit_to_diffusers.py index da3af8333e..65fcccb22a 100644 --- a/scripts/convert_hunyuandit_to_diffusers.py +++ b/scripts/convert_hunyuandit_to_diffusers.py @@ -13,15 +13,14 @@ def main(args): state_dict = state_dict[args.load_key] except KeyError: raise KeyError( - f"{args.load_key} not found in the checkpoint." - f"Please load from the following keys:{state_dict.keys()}" + f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}" ) device = "cuda" model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer") - model_config[ - "use_style_cond_and_image_meta_size" - ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False + model_config["use_style_cond_and_image_meta_size"] = ( + args.use_style_cond_and_image_meta_size + ) ### version <= v1.1: True; version >= v1.2: False # input_size -> sample_size, text_dim -> cross_attention_dim for key in state_dict: diff --git a/scripts/convert_k_upscaler_to_diffusers.py b/scripts/convert_k_upscaler_to_diffusers.py index 62abedd737..cff845ef80 100644 --- a/scripts/convert_k_upscaler_to_diffusers.py +++ b/scripts/convert_k_upscaler_to_diffusers.py @@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type): diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}" idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2 self_attention_prefix = f"{block_prefix}.{idx}" - cross_attention_prefix = f"{block_prefix}.{idx }" + cross_attention_prefix = f"{block_prefix}.{idx}" cross_attention_index = 1 if not attention.add_self_attention else 2 idx = ( n * attention_idx + cross_attention_index if block_type == "up" else n * attention_idx + cross_attention_index + 1 ) - cross_attention_prefix = f"{block_prefix}.{idx }" + cross_attention_prefix = f"{block_prefix}.{idx}" diffusers_checkpoint.update( cross_attn_to_diffusers_checkpoint( @@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config): block_out_channels = original_config["channels"] - assert ( - len(set(original_config["depths"])) == 1 - ), "UNet2DConditionModel currently do not support blocks with different number of layers" + assert len(set(original_config["depths"])) == 1, ( + "UNet2DConditionModel currently do not support blocks with different number of layers" + ) layers_per_block = original_config["depths"][0] class_labels_dim = original_config["mapping_cond_dim"] diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py index 9727deeb6b..64e4f69eac 100644 --- a/scripts/convert_mochi_to_diffusers.py +++ b/scripts/convert_mochi_to_diffusers.py @@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa # Convert block_in (MochiMidBlock3D) for i in range(3): # layers_per_block[-1] = 3 new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( - f"blocks.0.{i+1}.stack.0.weight" + f"blocks.0.{i + 1}.stack.0.weight" ) new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( - f"blocks.0.{i+1}.stack.0.bias" + f"blocks.0.{i + 1}.stack.0.bias" ) new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( - f"blocks.0.{i+1}.stack.2.weight" + f"blocks.0.{i + 1}.stack.2.weight" ) new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( - f"blocks.0.{i+1}.stack.2.bias" + f"blocks.0.{i + 1}.stack.2.bias" ) new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( - f"blocks.0.{i+1}.stack.3.weight" + f"blocks.0.{i + 1}.stack.3.weight" ) new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( - f"blocks.0.{i+1}.stack.3.bias" + f"blocks.0.{i + 1}.stack.3.bias" ) new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( - f"blocks.0.{i+1}.stack.5.weight" + f"blocks.0.{i + 1}.stack.5.weight" ) new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( - f"blocks.0.{i+1}.stack.5.bias" + f"blocks.0.{i + 1}.stack.5.bias" ) # Convert up_blocks (MochiUpBlock3D) @@ -197,33 +197,35 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa for block in range(3): for i in range(down_block_layers[block]): new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( - f"blocks.{block+1}.blocks.{i}.stack.0.weight" + f"blocks.{block + 1}.blocks.{i}.stack.0.weight" ) new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( - f"blocks.{block+1}.blocks.{i}.stack.0.bias" + f"blocks.{block + 1}.blocks.{i}.stack.0.bias" ) new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( - f"blocks.{block+1}.blocks.{i}.stack.2.weight" + f"blocks.{block + 1}.blocks.{i}.stack.2.weight" ) new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( - f"blocks.{block+1}.blocks.{i}.stack.2.bias" + f"blocks.{block + 1}.blocks.{i}.stack.2.bias" ) new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( - f"blocks.{block+1}.blocks.{i}.stack.3.weight" + f"blocks.{block + 1}.blocks.{i}.stack.3.weight" ) new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( - f"blocks.{block+1}.blocks.{i}.stack.3.bias" + f"blocks.{block + 1}.blocks.{i}.stack.3.bias" ) new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( - f"blocks.{block+1}.blocks.{i}.stack.5.weight" + f"blocks.{block + 1}.blocks.{i}.stack.5.weight" ) new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( - f"blocks.{block+1}.blocks.{i}.stack.5.bias" + f"blocks.{block + 1}.blocks.{i}.stack.5.bias" ) new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop( - f"blocks.{block+1}.proj.weight" + f"blocks.{block + 1}.proj.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop( + f"blocks.{block + 1}.proj.bias" ) - new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias") # Convert block_out (MochiMidBlock3D) for i in range(3): # layers_per_block[0] = 3 @@ -267,133 +269,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa # Convert block_in (MochiMidBlock3D) for i in range(3): # layers_per_block[0] = 3 new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( - f"layers.{i+1}.stack.0.weight" + f"layers.{i + 1}.stack.0.weight" ) new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( - f"layers.{i+1}.stack.0.bias" + f"layers.{i + 1}.stack.0.bias" ) new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( - f"layers.{i+1}.stack.2.weight" + f"layers.{i + 1}.stack.2.weight" ) new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( - f"layers.{i+1}.stack.2.bias" + f"layers.{i + 1}.stack.2.bias" ) new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( - f"layers.{i+1}.stack.3.weight" + f"layers.{i + 1}.stack.3.weight" ) new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( - f"layers.{i+1}.stack.3.bias" + f"layers.{i + 1}.stack.3.bias" ) new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( - f"layers.{i+1}.stack.5.weight" + f"layers.{i + 1}.stack.5.weight" ) new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( - f"layers.{i+1}.stack.5.bias" + f"layers.{i + 1}.stack.5.bias" ) # Convert down_blocks (MochiDownBlock3D) down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3] for block in range(3): new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.0.weight" + f"layers.{block + 4}.layers.0.weight" ) new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.0.bias" + f"layers.{block + 4}.layers.0.bias" ) for i in range(down_block_layers[block]): # Convert resnets - new_state_dict[ - f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" - ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight") + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = ( + encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight") + ) new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.stack.0.bias" + f"layers.{block + 4}.layers.{i + 1}.stack.0.bias" ) new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.stack.2.weight" + f"layers.{block + 4}.layers.{i + 1}.stack.2.weight" ) new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.stack.2.bias" + f"layers.{block + 4}.layers.{i + 1}.stack.2.bias" + ) + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = ( + encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight") ) - new_state_dict[ - f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight" - ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight") new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.stack.3.bias" + f"layers.{block + 4}.layers.{i + 1}.stack.3.bias" ) new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.stack.5.weight" + f"layers.{block + 4}.layers.{i + 1}.stack.5.weight" ) new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.stack.5.bias" + f"layers.{block + 4}.layers.{i + 1}.stack.5.bias" ) # Convert attentions - qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight") + qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight") q, k, v = qkv_weight.chunk(3, dim=0) new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight" + f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight" ) new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias" + f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias" ) new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight" + f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight" ) new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( - f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias" + f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias" ) # Convert block_out (MochiMidBlock3D) for i in range(3): # layers_per_block[-1] = 3 # Convert resnets new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( - f"layers.{i+7}.stack.0.weight" + f"layers.{i + 7}.stack.0.weight" ) new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( - f"layers.{i+7}.stack.0.bias" + f"layers.{i + 7}.stack.0.bias" ) new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( - f"layers.{i+7}.stack.2.weight" + f"layers.{i + 7}.stack.2.weight" ) new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( - f"layers.{i+7}.stack.2.bias" + f"layers.{i + 7}.stack.2.bias" ) new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( - f"layers.{i+7}.stack.3.weight" + f"layers.{i + 7}.stack.3.weight" ) new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( - f"layers.{i+7}.stack.3.bias" + f"layers.{i + 7}.stack.3.bias" ) new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( - f"layers.{i+7}.stack.5.weight" + f"layers.{i + 7}.stack.5.weight" ) new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( - f"layers.{i+7}.stack.5.bias" + f"layers.{i + 7}.stack.5.bias" ) # Convert attentions - qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight") + qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight") q, k, v = qkv_weight.chunk(3, dim=0) new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( - f"layers.{i+7}.attn_block.attn.out.weight" + f"layers.{i + 7}.attn_block.attn.out.weight" ) new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( - f"layers.{i+7}.attn_block.attn.out.bias" + f"layers.{i + 7}.attn_block.attn.out.bias" ) new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( - f"layers.{i+7}.attn_block.norm.weight" + f"layers.{i + 7}.attn_block.norm.weight" ) new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( - f"layers.{i+7}.attn_block.norm.bias" + f"layers.{i + 7}.attn_block.norm.bias" ) # Convert output layers diff --git a/scripts/convert_original_audioldm2_to_diffusers.py b/scripts/convert_original_audioldm2_to_diffusers.py index 1dc7d739ea..2c0695ce55 100644 --- a/scripts/convert_original_audioldm2_to_diffusers.py +++ b/scripts/convert_original_audioldm2_to_diffusers.py @@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint): # replace sequential layers with list sequential_layer = re.match(sequential_layers_pattern, key).group(1) - key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") + key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.") elif re.match(text_projection_pattern, key): projecton_layer = int(re.match(text_projection_pattern, key).group(1)) diff --git a/scripts/convert_original_audioldm_to_diffusers.py b/scripts/convert_original_audioldm_to_diffusers.py index 4f8e4f8f9f..44183f1aea 100644 --- a/scripts/convert_original_audioldm_to_diffusers.py +++ b/scripts/convert_original_audioldm_to_diffusers.py @@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint): # replace sequential layers with list sequential_layer = re.match(sequential_layers_pattern, key).group(1) - key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") + key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.") elif re.match(text_projection_pattern, key): projecton_layer = int(re.match(text_projection_pattern, key).group(1)) diff --git a/scripts/convert_original_musicldm_to_diffusers.py b/scripts/convert_original_musicldm_to_diffusers.py index 61e5d16eea..00836fde25 100644 --- a/scripts/convert_original_musicldm_to_diffusers.py +++ b/scripts/convert_original_musicldm_to_diffusers.py @@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint): # replace sequential layers with list sequential_layer = re.match(sequential_layers_pattern, key).group(1) - key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") + key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.") elif re.match(text_projection_pattern, key): projecton_layer = int(re.match(text_projection_pattern, key).group(1)) diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py index a0f9d0f87d..b33c8b0608 100644 --- a/scripts/convert_stable_audio.py +++ b/scripts/convert_stable_audio.py @@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay # get idx of the layer idx = int(new_key.split("coder.layers.")[1].split(".")[0]) - new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}") + new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}") if "encoder" in new_key: for i in range(3): - new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}") - new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1") - new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1") + new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}") + new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1") + new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1") else: for i in range(2, 5): - new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}") - new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1") - new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1") + new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}") + new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1") + new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1") new_key = new_key.replace("layers.0.beta", "snake1.beta") new_key = new_key.replace("layers.0.alpha", "snake1.alpha") @@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay new_key = new_key.replace("layers.3.weight_", "conv2.weight_") if idx == num_autoencoder_layers + 1: - new_key = new_key.replace(f"block.{idx-1}", "snake1") + new_key = new_key.replace(f"block.{idx - 1}", "snake1") elif idx == num_autoencoder_layers + 2: - new_key = new_key.replace(f"block.{idx-1}", "conv2") + new_key = new_key.replace(f"block.{idx - 1}", "conv2") else: new_key = new_key diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py index 3243ce294b..e46410ccb3 100644 --- a/scripts/convert_svd_to_diffusers.py +++ b/scripts/convert_svd_to_diffusers.py @@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint( # TODO resnet time_mixer.mix_factor if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: - new_checkpoint[ - f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" - ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"] + new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = ( + unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"] + ) if len(attentions): paths = renew_attention_paths(attentions) @@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint( ) if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: - new_checkpoint[ - f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" - ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"] + new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = ( + unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"] + ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} if ["conv.bias", "conv.weight"] in output_block_list.values(): diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py index 7da6b40949..fe62d18faf 100644 --- a/scripts/convert_vq_diffusion_to_diffusers.py +++ b/scripts/convert_vq_diffusion_to_diffusers.py @@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV def vqvae_model_from_original_config(original_config): - assert ( - original_config["target"] in PORTED_VQVAES - ), f"{original_config['target']} has not yet been ported to diffusers." + assert original_config["target"] in PORTED_VQVAES, ( + f"{original_config['target']} has not yet been ported to diffusers." + ) original_config = original_config["params"] @@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima def transformer_model_from_original_config( original_diffusion_config, original_transformer_config, original_content_embedding_config ): - assert ( - original_diffusion_config["target"] in PORTED_DIFFUSIONS - ), f"{original_diffusion_config['target']} has not yet been ported to diffusers." - assert ( - original_transformer_config["target"] in PORTED_TRANSFORMERS - ), f"{original_transformer_config['target']} has not yet been ported to diffusers." - assert ( - original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS - ), f"{original_content_embedding_config['target']} has not yet been ported to diffusers." + assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, ( + f"{original_diffusion_config['target']} has not yet been ported to diffusers." + ) + assert original_transformer_config["target"] in PORTED_TRANSFORMERS, ( + f"{original_transformer_config['target']} has not yet been ported to diffusers." + ) + assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, ( + f"{original_content_embedding_config['target']} has not yet been ported to diffusers." + ) original_diffusion_config = original_diffusion_config["params"] original_transformer_config = original_transformer_config["params"] diff --git a/setup.py b/setup.py index fdc166a81e..7c15d650c7 100644 --- a/setup.py +++ b/setup.py @@ -122,7 +122,7 @@ _deps = [ "pytest-timeout", "pytest-xdist", "python>=3.8.0", - "ruff==0.1.5", + "ruff==0.9.10", "safetensors>=0.3.1", "sentencepiece>=0.1.91,!=0.1.92", "GitPython<3.1.19", diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 8ec95ed6fc..520815d122 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -29,7 +29,7 @@ deps = { "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", "python": "python>=3.8.0", - "ruff": "ruff==0.1.5", + "ruff": "ruff==0.9.10", "safetensors": "safetensors>=0.3.1", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "GitPython": "GitPython<3.1.19", diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 21a1a70ff7..025f525214 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -295,8 +295,7 @@ class IPAdapterMixin: ): if len(scale_configs) != len(attn_processor.scale): raise ValueError( - f"Cannot assign {len(scale_configs)} scale_configs to " - f"{len(attn_processor.scale)} IP-Adapter." + f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter." ) elif len(scale_configs) == 1: scale_configs = scale_configs * len(attn_processor.scale) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 5ec16ff299..791b7ae9b1 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -184,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ # Store DoRA scale if present. if dora_present_in_unet: dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." - unet_state_dict[ - diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") - ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = ( + state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + ) # Handle text encoder LoRAs. elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): @@ -206,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." ) if lora_name.startswith(("lora_te_", "lora_te1_")): - te_state_dict[ - diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") - ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = ( + state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + ) elif lora_name.startswith("lora_te2_"): - te2_state_dict[ - diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") - ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = ( + state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + ) # Store alpha if present. if lora_name_alpha in state_dict: @@ -1020,21 +1020,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): for lora_key in ["lora_A", "lora_B"]: ## time_text_embed.timestep_embedder <- time_in - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" - ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") + converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = ( + original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") + ) if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" - ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") + converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = ( + original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") + ) - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" - ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") + converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = ( + original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") + ) if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[ - f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" - ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") + converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = ( + original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") + ) ## time_text_embed.text_embedder <- vector_in converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop( @@ -1056,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): # guidance has_guidance = any("guidance" in k for k in original_state_dict) if has_guidance: - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" - ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") + converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = ( + original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") + ) if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" - ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") + converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = ( + original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") + ) - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" - ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") + converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = ( + original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") + ) if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[ - f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" - ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") + converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = ( + original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") + ) # context_embedder converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 99a2f871c8..218394af28 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -26,6 +26,7 @@ _import_structure = {} if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] + _import_structure["auto_model"] = ["AutoModel"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] @@ -41,7 +42,6 @@ if is_torch_available(): _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] - _import_structure["auto_model"] = ["AutoModel"] _import_structure["cache_utils"] = ["CacheMixin"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 741f7075d7..ebc7d79aeb 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -205,7 +205,7 @@ def load_state_dict( ) from e except (UnicodeDecodeError, ValueError): raise OSError( - f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " + f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. " ) diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index a88ee6c9c9..5515a78850 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -211,9 +211,9 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): def _init_vectorized_inputs(self, norm_type): assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" - assert ( - self.config.num_vector_embeds is not None - ), "Transformer2DModel over discrete input must provide num_embed" + assert self.config.num_vector_embeds is not None, ( + "Transformer2DModel over discrete input must provide num_embed" + ) self.height = self.config.sample_size self.width = self.config.sample_size diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py index 1616d94ff1..f80771381b 100644 --- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py @@ -791,7 +791,7 @@ class AudioLDM2Pipeline(DiffusionPipeline): if transcription is None: if self.text_encoder_2.config.model_type == "vits": - raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription") + raise ValueError("Cannot forward without transcription. Please make sure to have transcription") elif transcription is not None and ( not isinstance(transcription, str) and not isinstance(transcription, list) ): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 16d3529ed3..2c63aedd96 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -657,7 +657,7 @@ class StableDiffusionControlNetInpaintPipeline( if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -665,7 +665,7 @@ class StableDiffusionControlNetInpaintPipeline( f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") # `prompt` needs more sophisticated handling when there are multiple # conditionings. diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py index 15745ecca3..aaec454cc7 100755 --- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py @@ -1130,7 +1130,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.transformer` or your `mask_image` or `image` input." ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index bff625367b..64cc8e13f3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -507,7 +507,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -515,7 +515,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 27b9e0cd45..ecd5a8967b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -574,7 +574,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -582,7 +582,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index dc0071a494..8ea5eb7dd5 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -341,9 +341,9 @@ class AnimateDiffFreeNoiseMixin: start_tensor = negative_prompt_embeds[i].unsqueeze(0) end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) - negative_prompt_interpolation_embeds[ - start_frame : end_frame + 1 - ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = ( + self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + ) prompt_embeds = prompt_interpolation_embeds negative_prompt_embeds = negative_prompt_interpolation_embeds diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py index e653b8266f..5f8db26eef 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py @@ -360,7 +360,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline): """ _load_connected_pipes = True - model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq" + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq" _exclude_from_cpu_offload = ["prior_prior"] def __init__( diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py index cce5f0b3d5..769c834ec3 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py @@ -579,7 +579,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline): f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index 75d272ac51..40fac01f8f 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -95,13 +95,13 @@ class OmniGenMultiModalProcessor: image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] unique_image_ids = sorted(set(image_ids)) - assert unique_image_ids == list( - range(1, len(unique_image_ids) + 1) - ), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}" + assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), ( + f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}" + ) # total images must be the same as the number of image tags - assert ( - len(unique_image_ids) == len(input_images) - ), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images" + assert len(unique_image_ids) == len(input_images), ( + f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images" + ) input_images = [input_images[x - 1] for x in image_ids] diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py index bc7a4b57af..6d89f16765 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py @@ -604,7 +604,7 @@ class StableDiffusionControlNetPAGInpaintPipeline( if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -612,7 +612,7 @@ class StableDiffusionControlNetPAGInpaintPipeline( f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") # `prompt` needs more sophisticated handling when there are multiple # conditionings. @@ -1340,7 +1340,7 @@ class StableDiffusionControlNetPAGInpaintPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py index 33abfb0be8..db652989cf 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py @@ -683,7 +683,7 @@ class StableDiffusionPAGInpaintPipeline( if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -691,7 +691,7 @@ class StableDiffusionPAGInpaintPipeline( f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( @@ -1191,7 +1191,7 @@ class StableDiffusionPAGInpaintPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index fdf3df2f4d..8b06bdc9c9 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -737,7 +737,7 @@ class StableDiffusionXLPAGInpaintPipeline( if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -745,7 +745,7 @@ class StableDiffusionXLPAGInpaintPipeline( f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( @@ -1509,7 +1509,7 @@ class StableDiffusionXLPAGInpaintPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 55a9f47145..288f269a65 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -575,7 +575,7 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin): f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index f5b430564c..89a403df8d 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -323,9 +323,7 @@ def maybe_raise_or_warn( model_cls = unwrapped_sub_model.__class__ if not issubclass(model_cls, expected_class_obj): - raise ValueError( - f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}" - ) + raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}") else: logger.warning( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py index 9d9f9d9b2a..dd25945590 100644 --- a/src/diffusers/pipelines/shap_e/renderer.py +++ b/src/diffusers/pipelines/shap_e/renderer.py @@ -983,9 +983,9 @@ class ShapERenderer(ModelMixin, ConfigMixin): fields = torch.cat(fields, dim=1) fields = fields.float() - assert ( - len(fields.shape) == 3 and fields.shape[-1] == 1 - ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" + assert len(fields.shape) == 3 and fields.shape[-1] == 1, ( + f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" + ) fields = fields.reshape(1, *([grid_size] * 3)) @@ -1039,9 +1039,9 @@ class ShapERenderer(ModelMixin, ConfigMixin): textures = textures.float() # 3.3 augument the mesh with texture data - assert len(textures.shape) == 3 and textures.shape[-1] == len( - texture_channels - ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" + assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), ( + f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" + ) for m, texture in zip(raw_meshes, textures): texture = texture[: len(m.verts)] diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py index 5d773b614a..1b87c02df0 100644 --- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py +++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py @@ -584,7 +584,7 @@ class StableAudioPipeline(DiffusionPipeline): if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: raise ValueError( - f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." + f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." ) waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py index abcba92616..dd659306e0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -335,7 +335,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index ddd2e27ded..f2e1d87be8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -475,7 +475,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): "Incorrect configuration settings! The config of `pipeline.unet` expects" f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 6f4e7f3589..0f7be1a1bb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -660,7 +660,7 @@ class StableDiffusionInpaintPipeline( if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -668,7 +668,7 @@ class StableDiffusionInpaintPipeline( f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( @@ -1226,7 +1226,7 @@ class StableDiffusionInpaintPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 7857bc58a8..e0748943ff 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -401,7 +401,7 @@ class StableDiffusionInstructPix2PixPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" + f" = {num_channels_latents + num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index c6967bc393..42db88b030 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -600,7 +600,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" + f" = {num_channels_latents + num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index dae4540ebe..f9b6dcbf5a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -740,7 +740,7 @@ class StableDiffusionUpscalePipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" + f" = {num_channels_latents + num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index c69fb90a4c..cac305a87f 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -1258,7 +1258,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.transformer` or your `mask_image` or `image` input." ) elif num_channels_transformer != 16: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 920caf4d24..835c0af800 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -741,7 +741,7 @@ class StableDiffusionXLInpaintPipeline( if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( @@ -749,7 +749,7 @@ class StableDiffusionXLInpaintPipeline( f" {type(mask_image)}." ) if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( @@ -1509,7 +1509,7 @@ class StableDiffusionXLInpaintPipeline( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 487ad2d80a..775d86dcfc 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -334,7 +334,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." ) if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): - raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}") + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py index 1c75b5bef9..fa9ba98e6d 100644 --- a/src/diffusers/quantizers/base.py +++ b/src/diffusers/quantizers/base.py @@ -215,19 +215,15 @@ class DiffusersQuantizer(ABC): ) @abstractmethod - def _process_model_before_weight_loading(self, model, **kwargs): - ... + def _process_model_before_weight_loading(self, model, **kwargs): ... @abstractmethod - def _process_model_after_weight_loading(self, model, **kwargs): - ... + def _process_model_after_weight_loading(self, model, **kwargs): ... @property @abstractmethod - def is_serializable(self): - ... + def is_serializable(self): ... @property @abstractmethod - def is_trainable(self): - ... + def is_trainable(self): ... diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 653171638c..c946fa1681 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -203,8 +203,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): if timesteps[0] >= self.config.num_train_timesteps: raise ValueError( - f"`timesteps` must start before `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps}." + f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}." ) timesteps = np.array(timesteps, dtype=np.int64) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 624d5a5cd4..f9eb9c365a 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -279,8 +279,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): if timesteps[0] >= self.config.num_train_timesteps: raise ValueError( - f"`timesteps` must start before `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps}." + f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}." ) timesteps = np.array(timesteps, dtype=np.int64) diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index 20ad7a4c92..64195be141 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -289,8 +289,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): if timesteps[0] >= self.config.num_train_timesteps: raise ValueError( - f"`timesteps` must start before `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps}." + f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}." ) timesteps = np.array(timesteps, dtype=np.int64) diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index 686b686f68..2a0cce7bf1 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -413,8 +413,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): if timesteps[0] >= self.config.num_train_timesteps: raise ValueError( - f"`timesteps` must start before `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps}." + f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}." ) # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1 diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index 5d60383142..77770ab206 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -431,8 +431,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): if timesteps[0] >= self.config.num_train_timesteps: raise ValueError( - f"`timesteps` must start before `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps}." + f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}." ) # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1 diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index c570bac733..b98c4e33f8 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -241,7 +241,7 @@ def _set_state_dict_into_text_encoder( """ text_encoder_state_dict = { - f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix) + f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix) } text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict)) set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default") @@ -583,7 +583,7 @@ class EMAModel: """ if self.temp_stored_params is None: - raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") + raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`") if self.foreach: torch._foreach_copy_( [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params] diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py index f482deddd2..4f001b3047 100644 --- a/src/diffusers/utils/deprecation_utils.py +++ b/src/diffusers/utils/deprecation_utils.py @@ -40,7 +40,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn line_number = call_frame.lineno function = call_frame.function key, value = next(iter(deprecated_kwargs.items())) - raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + raise TypeError(f"{function} in {filename} line {line_number - 1} got an unexpected keyword argument `{key}`") if len(values) == 0: return diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py index 6f93450c41..b96e0e222c 100644 --- a/src/diffusers/utils/logging.py +++ b/src/diffusers/utils/logging.py @@ -60,8 +60,7 @@ def _get_default_logging_level() -> int: return log_levels[env_level_str] else: logging.getLogger().warning( - f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " - f"has to be one of: { ', '.join(log_levels.keys()) }" + f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}" ) return _default_log_level diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index f23fddd286..3682c5bfac 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -334,7 +334,7 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs): kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names kohya_ss_state_dict[kohya_key] = weight if "lora_down" in kohya_key: - alpha_key = f'{kohya_key.split(".")[0]}.alpha' + alpha_key = f"{kohya_key.split('.')[0]}.alpha" kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight)) return kohya_ss_state_dict diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 51e7e640fb..4ba6f7c25e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -882,7 +882,7 @@ def pytest_terminal_summary_main(tr, id): f.write("slowest durations\n") for i, rep in enumerate(dlist): if rep.duration < durations_min: - f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") + f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted") break f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") @@ -1027,7 +1027,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): process.join(timeout=timeout) if results["error"] is not None: - test_case.fail(f'{results["error"]}') + test_case.fail(f"{results['error']}") class CaptureLogger: diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 74bd43c523..53aafc9d2e 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -168,9 +168,7 @@ class HookTests(unittest.TestCase): registry.register_hook(MultiplyHook(2), "multiply_hook") registry_repr = repr(registry) - expected_repr = ( - "HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")" - ) + expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)" self.assertEqual(len(registry.hooks), 2) self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"]) @@ -285,12 +283,7 @@ class HookTests(unittest.TestCase): self.model(input) output = cap_logger.out.replace(" ", "").replace("\n", "") expected_invocation_order_log = ( - ( - "MultiplyHook pre_forward\n" - "AddHook pre_forward\n" - "AddHook post_forward\n" - "MultiplyHook post_forward\n" - ) + ("MultiplyHook pre_forward\nAddHook pre_forward\nAddHook post_forward\nMultiplyHook post_forward\n") .replace(" ", "") .replace("\n", "") ) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 6155ac2e39..f82a2407f3 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -299,9 +299,9 @@ class ModelUtilsTest(unittest.TestCase): ) download_requests = [r.method for r in m.request_history] - assert ( - download_requests.count("HEAD") == 3 - ), "3 HEAD requests one for config, one for model, and one for shard index file." + assert download_requests.count("HEAD") == 3, ( + "3 HEAD requests one for config, one for model, and one for shard index file." + ) assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model" with requests_mock.mock(real_http=True) as m: @@ -313,9 +313,9 @@ class ModelUtilsTest(unittest.TestCase): ) cache_requests = [r.method for r in m.request_history] - assert ( - "HEAD" == cache_requests[0] and len(cache_requests) == 2 - ), "We should call only `model_info` to check for commit hash and knowing if shard index is present." + assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, ( + "We should call only `model_info` to check for commit hash and knowing if shard index is present." + ) def test_weight_overwrite(self): with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py index 659d9a82fd..bfef1fc4f0 100644 --- a/tests/models/transformers/test_models_transformer_sd3.py +++ b/tests/models/transformers/test_models_transformer_sd3.py @@ -92,9 +92,9 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase): model.enable_xformers_memory_efficient_attention() - assert ( - model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor" - ), "xformers is not enabled" + assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", ( + "xformers is not enabled" + ) @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply") def test_set_attn_processor_for_determinism(self): @@ -167,9 +167,9 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase): model.enable_xformers_memory_efficient_attention() - assert ( - model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor" - ), "xformers is not enabled" + assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", ( + "xformers is not enabled" + ) @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply") def test_set_attn_processor_for_determinism(self): diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 8e1187f114..d01a0b4935 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -654,22 +654,22 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype) full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample - assert full_cond_keepallmask_out.allclose( - full_cond_out, rtol=1e-05, atol=1e-05 - ), "a 'keep all' mask should give the same result as no mask" + assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), ( + "a 'keep all' mask should give the same result as no mask" + ) trunc_cond = cond[:, :-1, :] trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample - assert not trunc_cond_out.allclose( - full_cond_out, rtol=1e-05, atol=1e-05 - ), "discarding the last token from our cond should change the result" + assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), ( + "discarding the last token from our cond should change the result" + ) batch, tokens, _ = cond.shape mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample - assert masked_cond_out.allclose( - trunc_cond_out, rtol=1e-05, atol=1e-05 - ), "masking the last token from our cond should be equivalent to truncating that token out of the condition" + assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), ( + "masking the last token from our cond should be equivalent to truncating that token out of the condition" + ) # see diffusers.models.attention_processor::Attention#prepare_attention_mask # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. @@ -697,9 +697,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool) trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample - assert trunc_mask_out.allclose( - keeplast_out - ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." + assert trunc_mask_out.allclose(keeplast_out), ( + "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." + ) def test_custom_diffusion_processors(self): # enable deterministic behavior for gradient checkpointing @@ -1114,12 +1114,12 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test with torch.no_grad(): lora_sample_2 = model(**inputs_dict).sample - assert not torch.allclose( - non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4 - ), "LoRA injected UNet should produce different results." - assert torch.allclose( - lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4 - ), "Loading from a saved checkpoint should produce identical results." + assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), ( + "LoRA injected UNet should produce different results." + ) + assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), ( + "Loading from a saved checkpoint should produce identical results." + ) @require_peft_backend def test_save_attn_procs_raise_warning(self): diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py index 3397ca9e39..071194c59e 100644 --- a/tests/others/test_image_processor.py +++ b/tests/others/test_image_processor.py @@ -65,9 +65,9 @@ class ImageProcessorTest(unittest.TestCase): ) out_np = self.to_np(out) in_np = (input_np * 255).round() if output_type == "pil" else input_np - assert ( - np.abs(in_np - out_np).max() < 1e-6 - ), f"decoded output does not match input for output_type {output_type}" + assert np.abs(in_np - out_np).max() < 1e-6, ( + f"decoded output does not match input for output_type {output_type}" + ) def test_vae_image_processor_np(self): image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) @@ -78,9 +78,9 @@ class ImageProcessorTest(unittest.TestCase): out_np = self.to_np(out) in_np = (input_np * 255).round() if output_type == "pil" else input_np - assert ( - np.abs(in_np - out_np).max() < 1e-6 - ), f"decoded output does not match input for output_type {output_type}" + assert np.abs(in_np - out_np).max() < 1e-6, ( + f"decoded output does not match input for output_type {output_type}" + ) def test_vae_image_processor_pil(self): image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) @@ -93,9 +93,9 @@ class ImageProcessorTest(unittest.TestCase): for i, o in zip(input_pil, out): in_np = np.array(i) out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round() - assert ( - np.abs(in_np - out_np).max() < 1e-6 - ), f"decoded output does not match input for output_type {output_type}" + assert np.abs(in_np - out_np).max() < 1e-6, ( + f"decoded output does not match input for output_type {output_type}" + ) def test_preprocess_input_3d(self): image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) @@ -293,9 +293,9 @@ class ImageProcessorTest(unittest.TestCase): scale = 2 out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale) exp_pt_shape = (b, c, h // scale, w // scale) - assert ( - out_pt.shape == exp_pt_shape - ), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'." + assert out_pt.shape == exp_pt_shape, ( + f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'." + ) def test_vae_image_processor_resize_np(self): image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1) @@ -305,6 +305,6 @@ class ImageProcessorTest(unittest.TestCase): input_np = self.to_np(input_pt) out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale) exp_np_shape = (b, h // scale, w // scale, c) - assert ( - out_np.shape == exp_np_shape - ), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'." + assert out_np.shape == exp_np_shape, ( + f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'." + ) diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py index a0fbc5df1c..ac579bbf2b 100644 --- a/tests/pipelines/amused/test_amused.py +++ b/tests/pipelines/amused/test_amused.py @@ -126,8 +126,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) @unittest.skip("aMUSEd does not support lists of generators") - def test_inference_batch_single_identical(self): - ... + def test_inference_batch_single_identical(self): ... @slow diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py index 2699bbe7f5..942735f157 100644 --- a/tests/pipelines/amused/test_amused_img2img.py +++ b/tests/pipelines/amused/test_amused_img2img.py @@ -126,8 +126,7 @@ class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) @unittest.skip("aMUSEd does not support lists of generators") - def test_inference_batch_single_identical(self): - ... + def test_inference_batch_single_identical(self): ... @slow diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py index 645379a7ea..541b988f17 100644 --- a/tests/pipelines/amused/test_amused_inpaint.py +++ b/tests/pipelines/amused/test_amused_inpaint.py @@ -130,8 +130,7 @@ class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) @unittest.skip("aMUSEd does not support lists of generators") - def test_inference_batch_single_identical(self): - ... + def test_inference_batch_single_identical(self): ... @slow diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py index c56aeb905a..1eb9d1035c 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py @@ -106,9 +106,9 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -122,15 +122,15 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin): image = pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) @unittest.skip("xformers attention processor does not exist for AuraFlow") def test_xformers_attention_forwardGenerator_pass(self): diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py index e073f55aec..db8d36b23a 100644 --- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py +++ b/tests/pipelines/blipdiffusion/test_blipdiffusion.py @@ -195,9 +195,9 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): [0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}" + ) @unittest.skip("Test not supported because of complexities in deriving query_embeds.") def test_encode_prompt_works_in_isolation(self): diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 388dc9ef7e..a9de0ff05f 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -299,9 +299,9 @@ class CogVideoXPipelineFastTests( original_image_slice = frames[0, -2:, -1, -3:, -3:] pipe.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -315,15 +315,15 @@ class CogVideoXPipelineFastTests( frames = pipe(**inputs).frames image_slice_disabled = frames[0, -2:, -1, -3:, -3:] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) @slow diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py index 2e962bd247..4f32da7ac4 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py +++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py @@ -299,9 +299,9 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas original_image_slice = frames[0, -2:, -1, -3:, -3:] pipe.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -315,12 +315,12 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas frames = pipe(**inputs).frames image_slice_disabled = frames[0, -2:, -1, -3:, -3:] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py index cac47f1a83..ec4e51bd1b 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py @@ -317,9 +317,9 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC original_image_slice = frames[0, -2:, -1, -3:, -3:] pipe.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -333,15 +333,15 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC frames = pipe(**inputs).frames image_slice_disabled = frames[0, -2:, -1, -3:, -3:] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) @slow diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py index 4d836cb5e2..b1ac8cbd90 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py @@ -298,9 +298,9 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC original_image_slice = frames[0, -2:, -1, -3:, -3:] pipe.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -314,12 +314,12 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC frames = pipe(**inputs).frames image_slice_disabled = frames[0, -2:, -1, -3:, -3:] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py index eedda4e217..a5768cb51f 100644 --- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py +++ b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py @@ -219,9 +219,9 @@ class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.Tes assert image.shape == (1, 16, 16, 4) expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) @unittest.skip("Test not supported because of complexities in deriving query_embeds.") def test_encode_prompt_works_in_isolation(self): diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index 9a270c2bbf..9ce62cde9f 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -178,9 +178,9 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl [0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f"Expected: {expected_slice}, got: {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f"Expected: {expected_slice}, got: {image_slice.flatten()}" + ) @unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention") def test_xformers_attention_forwardGenerator_pass(self): diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py index 59ccb92378..8d63619c40 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -170,9 +170,9 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi original_image_slice = image[0, -3:, -3:, -1] pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -186,15 +186,15 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi image = pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) def test_flux_image_output_shape(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py index f7b3db05c8..4bd7f59dc0 100644 --- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py +++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py @@ -162,9 +162,9 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix [0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f"Expected: {expected_slice}, got: {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f"Expected: {expected_slice}, got: {image_slice.flatten()}" + ) def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical( diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py index 2cd57ce56d..d9f5dcad7d 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py @@ -194,9 +194,9 @@ class StableDiffusion3ControlInpaintNetPipelineFastTests(unittest.TestCase, Pipe [0.51708984, 0.7421875, 0.4580078, 0.6435547, 0.65625, 0.43603516, 0.5151367, 0.65722656, 0.60839844] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f"Expected: {expected_slice}, got: {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f"Expected: {expected_slice}, got: {image_slice.flatten()}" + ) @unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention") def test_xformers_attention_forwardGenerator_pass(self): diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 84ce09acbe..1be15645ef 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -202,9 +202,9 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes else: expected_slice = np.array([1.0000, 0.9072, 0.4209, 0.2744, 0.5737, 0.3840, 0.6113, 0.6250, 0.6328]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f"Expected: {expected_slice}, got: {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f"Expected: {expected_slice}, got: {image_slice.flatten()}" + ) def test_controlnet_sd3(self): components = self.get_dummy_components() diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py index 30883ac4a6..18732c0058 100644 --- a/tests/pipelines/dit/test_dit.py +++ b/tests/pipelines/dit/test_dit.py @@ -149,8 +149,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase): for word, image in zip(words, images): expected_image = load_numpy( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - f"/dit/{word}_512.npy" + f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy" ) assert np.abs((expected_image - image).max()) < 1e-1 diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 6a560367a5..646ad928ec 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -170,9 +170,9 @@ class FluxPipelineFastTests( # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -186,15 +186,15 @@ class FluxPipelineFastTests( image = pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) def test_flux_image_output_shape(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py index d8293952ad..d8d0774e1e 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control.py +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -140,9 +140,9 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -156,15 +156,15 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): image = pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) def test_flux_image_output_shape(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py index 44ce2a4ded..a2f7c91710 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py @@ -134,9 +134,9 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -150,15 +150,15 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin image = pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) def test_flux_image_output_shape(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) diff --git a/tests/pipelines/hunyuandit/test_hunyuan_dit.py b/tests/pipelines/hunyuandit/test_hunyuan_dit.py index 5b1a82eda2..66453b73b0 100644 --- a/tests/pipelines/hunyuandit/test_hunyuan_dit.py +++ b/tests/pipelines/hunyuandit/test_hunyuan_dit.py @@ -21,12 +21,7 @@ import numpy as np import torch from transformers import AutoTokenizer, BertModel, T5EncoderModel -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - HunyuanDiT2DModel, - HunyuanDiTPipeline, -) +from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, @@ -179,9 +174,9 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -197,15 +192,15 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_disabled = pipe(**inputs)[0] image_slice_disabled = image_disabled[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) @unittest.skip( "Test not supported as `encode_prompt` is called two times separately which deivates from about 99% of the pipelines we have." diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py index 30144e37a9..f4de6f3a53 100644 --- a/tests/pipelines/kandinsky/test_kandinsky.py +++ b/tests/pipelines/kandinsky/test_kandinsky.py @@ -240,12 +240,12 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase): expected_slice = np.array([1.0000, 1.0000, 0.2766, 1.0000, 0.5447, 0.1737, 1.0000, 0.4316, 0.9024]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py index c5f27a9cc9..f14a741d7d 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py @@ -98,12 +98,12 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase) expected_slice = np.array([0.2893, 0.1464, 0.4603, 0.3529, 0.4612, 0.7701, 0.4027, 0.3051, 0.5155]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): @@ -206,12 +206,12 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te expected_slice = np.array([0.4852, 0.4136, 0.4539, 0.4781, 0.4680, 0.5217, 0.4973, 0.4089, 0.4977]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): @@ -318,12 +318,12 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te expected_slice = np.array([0.0320, 0.0860, 0.4013, 0.0518, 0.2484, 0.5847, 0.4411, 0.2321, 0.4593]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py index 26361ce18b..1697099780 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py +++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py @@ -261,12 +261,12 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5816, 0.5872, 0.4634, 0.5982, 0.4767, 0.4710, 0.4669, 0.4717, 0.4966]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): @@ -321,7 +321,7 @@ class KandinskyImg2ImgPipelineIntegrationTests(unittest.TestCase): ) init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ) prompt = "A red cartoon frog, 4k" @@ -387,7 +387,7 @@ class KandinskyImg2ImgPipelineNightlyTests(unittest.TestCase): ) init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/frog.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png" ) prompt = "A red cartoon frog, 4k" diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py index e30c601b60..d4d5c4e48f 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py @@ -256,12 +256,12 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): expected_slice = np.array([0.8222, 0.8896, 0.4373, 0.8088, 0.4905, 0.2609, 0.6816, 0.4291, 0.5129]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) @@ -319,7 +319,7 @@ class KandinskyInpaintPipelineIntegrationTests(unittest.TestCase): ) init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ) mask = np.zeros((768, 768), dtype=np.float32) mask[:250, 250:-250] = 1 diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky.py b/tests/pipelines/kandinsky2_2/test_kandinsky.py index fea49d47b7..aa17f6fc5d 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky.py @@ -210,13 +210,13 @@ class KandinskyV22PipelineFastTests(PipelineTesterMixin, unittest.TestCase): expected_slice = np.array([0.3420, 0.9505, 0.3919, 1.0000, 0.5188, 0.3109, 0.6139, 0.5624, 0.6811]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) def test_float16_inference(self): super().test_float16_inference(expected_max_diff=1e-1) diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py index 90f8b20341..17ef3dc260 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py @@ -103,12 +103,12 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa expected_slice = np.array([0.3076, 0.2729, 0.5668, 0.0522, 0.3384, 0.7028, 0.4908, 0.3659, 0.6243]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): @@ -227,12 +227,12 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest expected_slice = np.array([0.4445, 0.4287, 0.4596, 0.3919, 0.3730, 0.5039, 0.4834, 0.4269, 0.5521]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): @@ -350,12 +350,12 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest expected_slice = np.array([0.5039, 0.4926, 0.4898, 0.4978, 0.4838, 0.4942, 0.4738, 0.4702, 0.4816]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py index 1f3219e0d6..10a95d6177 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py @@ -210,13 +210,13 @@ class KandinskyV22ControlnetPipelineFastTests(PipelineTesterMixin, unittest.Test [0.6959826, 0.868279, 0.7558092, 0.68769467, 0.85805804, 0.65977496, 0.44885302, 0.5959111, 0.4251595] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) def test_float16_inference(self): super().test_float16_inference(expected_max_diff=1e-1) diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py index 20944aa3d6..58fbbecc05 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py @@ -218,12 +218,12 @@ class KandinskyV22ControlnetImg2ImgPipelineFastTests(PipelineTesterMixin, unitte expected_slice = np.array( [0.54985034, 0.55509365, 0.52561504, 0.5570494, 0.5593818, 0.5263979, 0.50285643, 0.5069846, 0.51196736] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=1.75e-3) @@ -254,7 +254,7 @@ class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase): ) init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ) init_image = init_image.resize((512, 512)) diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py index 4702f473a9..aa7589a212 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py @@ -228,12 +228,12 @@ class KandinskyV22Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCas assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5712, 0.5443, 0.4725, 0.6195, 0.5184, 0.4651, 0.4473, 0.4590, 0.5016]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) def test_float16_inference(self): super().test_float16_inference(expected_max_diff=2e-1) @@ -261,7 +261,7 @@ class KandinskyV22Img2ImgPipelineIntegrationTests(unittest.TestCase): ) init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ) prompt = "A red cartoon frog, 4k" diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py index 9a7f659e53..d7ac698207 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py @@ -234,12 +234,12 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas [0.50775903, 0.49527195, 0.48824543, 0.50192237, 0.48644906, 0.49373814, 0.4780598, 0.47234827, 0.48327848] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) @@ -314,7 +314,7 @@ class KandinskyV22InpaintPipelineIntegrationTests(unittest.TestCase): ) init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ) mask = np.zeros((768, 768), dtype=np.float32) mask[:250, 250:-250] = 1 diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py index af1d45ff89..c54b91f024 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3.py @@ -157,9 +157,9 @@ class Kandinsky3PipelineFastTests(PipelineTesterMixin, unittest.TestCase): expected_slice = np.array([0.3768, 0.4373, 0.4865, 0.4890, 0.4299, 0.5122, 0.4921, 0.4924, 0.5599]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) def test_float16_inference(self): super().test_float16_inference(expected_max_diff=1e-1) diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py index e00948621a..088c32e286 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py @@ -181,9 +181,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase) [0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) def test_float16_inference(self): super().test_float16_inference(expected_max_diff=1e-1) diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 6fa9627540..b9ce29c70b 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -450,9 +450,9 @@ class AnimateDiffPAGPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).frames[0, -3:, -3:, -1] components = self.get_dummy_components() diff --git a/tests/pipelines/pag/test_pag_controlnet_sd.py b/tests/pipelines/pag/test_pag_controlnet_sd.py index ee97b0507a..02232c7379 100644 --- a/tests/pipelines/pag/test_pag_controlnet_sd.py +++ b/tests/pipelines/pag/test_pag_controlnet_sd.py @@ -169,9 +169,9 @@ class StableDiffusionControlNetPAGPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 diff --git a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py index 25ef5d253d..cfc0b218d2 100644 --- a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py +++ b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py @@ -165,9 +165,9 @@ class StableDiffusionControlNetPAGInpaintPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl.py b/tests/pipelines/pag/test_pag_controlnet_sdxl.py index 0588e26286..10adff7fe0 100644 --- a/tests/pipelines/pag/test_pag_controlnet_sdxl.py +++ b/tests/pipelines/pag/test_pag_controlnet_sdxl.py @@ -187,9 +187,9 @@ class StableDiffusionXLControlNetPAGPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py index 63c7d9fbee..fe4b615f64 100644 --- a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py +++ b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py @@ -189,9 +189,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py index 31cd9aa666..d6cfbbed9e 100644 --- a/tests/pipelines/pag/test_pag_hunyuan_dit.py +++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py @@ -177,15 +177,15 @@ class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_disabled = pipe(**inputs)[0] image_slice_disabled = image_disabled[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) def test_pag_disable_enable(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -198,9 +198,9 @@ class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] components = self.get_dummy_components() diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py index 9a4f1daa2c..c9f197b703 100644 --- a/tests/pipelines/pag/test_pag_kolors.py +++ b/tests/pipelines/pag/test_pag_kolors.py @@ -140,9 +140,9 @@ class KolorsPAGPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py index 63f42416db..624b578443 100644 --- a/tests/pipelines/pag/test_pag_pixart_sigma.py +++ b/tests/pipelines/pag/test_pag_pixart_sigma.py @@ -120,9 +120,9 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}." + ) out = pipe(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 diff --git a/tests/pipelines/pag/test_pag_sana.py b/tests/pipelines/pag/test_pag_sana.py index a2c6572978..ee1e359383 100644 --- a/tests/pipelines/pag/test_pag_sana.py +++ b/tests/pipelines/pag/test_pag_sana.py @@ -268,9 +268,9 @@ class SanaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] components = self.get_dummy_components() diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py index d4cf00b034..bc20226873 100644 --- a/tests/pipelines/pag/test_pag_sd.py +++ b/tests/pipelines/pag/test_pag_sd.py @@ -154,9 +154,9 @@ class StableDiffusionPAGPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 @@ -328,9 +328,9 @@ class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) def test_pag_uncond(self): pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) @@ -345,6 +345,6 @@ class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) diff --git a/tests/pipelines/pag/test_pag_sd3.py b/tests/pipelines/pag/test_pag_sd3.py index 41ff0c3c09..737e238e5f 100644 --- a/tests/pipelines/pag/test_pag_sd3.py +++ b/tests/pipelines/pag/test_pag_sd3.py @@ -170,9 +170,9 @@ class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixi # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -186,15 +186,15 @@ class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixi image = pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) def test_pag_disable_enable(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -207,9 +207,9 @@ class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixi inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] components = self.get_dummy_components() diff --git a/tests/pipelines/pag/test_pag_sd3_img2img.py b/tests/pipelines/pag/test_pag_sd3_img2img.py index 2fe9889291..fe593d47dc 100644 --- a/tests/pipelines/pag/test_pag_sd3_img2img.py +++ b/tests/pipelines/pag/test_pag_sd3_img2img.py @@ -149,9 +149,9 @@ class StableDiffusion3PAGImg2ImgPipelineFastTests(unittest.TestCase, PipelineTes inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] components = self.get_dummy_components() @@ -254,9 +254,9 @@ class StableDiffusion3PAGImg2ImgPipelineIntegrationTests(unittest.TestCase): 0.17822266, ] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) def test_pag_uncond(self): pipeline = AutoPipelineForImage2Image.from_pretrained( @@ -272,6 +272,6 @@ class StableDiffusion3PAGImg2ImgPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.1508789, 0.16210938, 0.17138672, 0.16210938, 0.17089844, 0.16137695, 0.16235352, 0.16430664, 0.16455078] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py index d000493d6b..ef70985571 100644 --- a/tests/pipelines/pag/test_pag_sd_img2img.py +++ b/tests/pipelines/pag/test_pag_sd_img2img.py @@ -161,9 +161,9 @@ class StableDiffusionPAGImg2ImgPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 @@ -267,9 +267,9 @@ class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) def test_pag_uncond(self): pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) @@ -285,6 +285,6 @@ class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase): [0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py index 06682c111d..04ec8b2165 100644 --- a/tests/pipelines/pag/test_pag_sd_inpaint.py +++ b/tests/pipelines/pag/test_pag_sd_inpaint.py @@ -302,9 +302,9 @@ class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.38793945, 0.4111328, 0.47924805, 0.39208984, 0.4165039, 0.41674805, 0.37060547, 0.36791992, 0.40625] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) def test_pag_uncond(self): pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) @@ -319,6 +319,6 @@ class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.3876953, 0.40356445, 0.4934082, 0.39697266, 0.41674805, 0.41015625, 0.375, 0.36914062, 0.40649414] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py index b35b2b1d2f..fc4ce1067f 100644 --- a/tests/pipelines/pag/test_pag_sdxl.py +++ b/tests/pipelines/pag/test_pag_sdxl.py @@ -167,9 +167,9 @@ class StableDiffusionXLPAGPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 @@ -331,9 +331,9 @@ class StableDiffusionXLPAGPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.3123679, 0.31725878, 0.32026544, 0.327533, 0.3266391, 0.3303998, 0.33544615, 0.34181812, 0.34102726] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) def test_pag_uncond(self): pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) @@ -348,6 +348,6 @@ class StableDiffusionXLPAGPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.47400922, 0.48650584, 0.4839625, 0.4724013, 0.4890427, 0.49544555, 0.51707107, 0.54299414, 0.5224372] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py index c94a6836de..0e5c2cc7f9 100644 --- a/tests/pipelines/pag/test_pag_sdxl_img2img.py +++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py @@ -215,9 +215,9 @@ class StableDiffusionXLPAGImg2ImgPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 @@ -316,9 +316,9 @@ class StableDiffusionXLPAGImg2ImgPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.20301354, 0.21078318, 0.2021082, 0.20277798, 0.20681083, 0.19562206, 0.20121682, 0.21562952, 0.21277016] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) def test_pag_uncond(self): pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) @@ -333,6 +333,6 @@ class StableDiffusionXLPAGImg2ImgPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.21303111, 0.22188407, 0.2124992, 0.21365267, 0.18823743, 0.17569828, 0.21113116, 0.19419771, 0.18919235] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py index cca5292288..854c65cbc7 100644 --- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py +++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py @@ -220,9 +220,9 @@ class StableDiffusionXLPAGInpaintPipelineFastTests( inputs = self.get_dummy_inputs(device) del inputs["pag_scale"] - assert ( - "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, ( + f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + ) out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 @@ -322,9 +322,9 @@ class StableDiffusionXLPAGInpaintPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.41385046, 0.39608297, 0.4360491, 0.26872507, 0.32187328, 0.4242474, 0.2603805, 0.34167895, 0.46561807] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) def test_pag_uncond(self): pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) @@ -339,6 +339,6 @@ class StableDiffusionXLPAGInpaintPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [0.41597816, 0.39302617, 0.44287828, 0.2687074, 0.28315824, 0.40582314, 0.20877528, 0.2380802, 0.39447647] ) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - ), f"output is different from expected, {image_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, ( + f"output is different from expected, {image_slice.flatten()}" + ) diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py index b220afcfc2..7084fc9bce 100644 --- a/tests/pipelines/pixart_sigma/test_pixart.py +++ b/tests/pipelines/pixart_sigma/test_pixart.py @@ -260,9 +260,9 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -276,15 +276,15 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image = pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) @slow diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py index ac7096874b..72eee3e35e 100644 --- a/tests/pipelines/shap_e/test_shap_e_img2img.py +++ b/tests/pipelines/shap_e/test_shap_e_img2img.py @@ -266,7 +266,7 @@ class ShapEImg2ImgPipelineIntegrationTests(unittest.TestCase): def test_shap_e_img2img(self): input_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/shap_e/corgi.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/corgi.png" ) expected_image = load_numpy( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py index 1765f3a022..d433a461bd 100644 --- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py +++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py @@ -198,12 +198,12 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC assert image.shape == (1, 128, 128, 3) expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 6e17b86639..3b5c7a24b4 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -293,15 +293,15 @@ class StableDiffusionPipelineFastTests( inputs["sigmas"] = sigma_schedule output_sigmas = sd_pipe(**inputs).images - assert ( - np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3 - ), "ays timesteps and ays sigmas should have the same outputs" - assert ( - np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3 - ), "use ays timesteps should have different outputs" - assert ( - np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3 - ), "use ays sigmas should have different outputs" + assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, ( + "ays timesteps and ays sigmas should have the same outputs" + ) + assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, ( + "use ays timesteps should have different outputs" + ) + assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, ( + "use ays sigmas should have different outputs" + ) def test_stable_diffusion_prompt_embeds(self): components = self.get_dummy_components() @@ -656,9 +656,9 @@ class StableDiffusionPipelineFastTests( sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images - assert not np.allclose( - output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1] - ), "Enabling of FreeU should lead to different results." + assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), ( + "Enabling of FreeU should lead to different results." + ) def test_freeu_disabled(self): components = self.get_dummy_components() @@ -681,9 +681,9 @@ class StableDiffusionPipelineFastTests( prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0) ).images - assert np.allclose( - output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1] - ), "Disabling of FreeU should lead to results similar to the default pipeline results." + assert np.allclose(output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]), ( + "Disabling of FreeU should lead to results similar to the default pipeline results." + ) def test_fused_qkv_projections(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -706,15 +706,15 @@ class StableDiffusionPipelineFastTests( image = sd_pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) def test_pipeline_interrupt(self): components = self.get_dummy_components() diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 38ef6143f4..8e2fa77fc0 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -171,9 +171,9 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_processors_exist(pipe.transformer), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." @@ -187,15 +187,15 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): image = pipe(**inputs).images image_slice_disabled = image[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) def test_skip_guidance_layers(self): components = self.get_dummy_components() diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index c68cdf6703..a41e7dc7f3 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -242,15 +242,15 @@ class StableDiffusionXLPipelineFastTests( inputs["sigmas"] = sigma_schedule output_sigmas = sd_pipe(**inputs).images - assert ( - np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3 - ), "ays timesteps and ays sigmas should have the same outputs" - assert ( - np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3 - ), "use ays timesteps should have different outputs" - assert ( - np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3 - ), "use ays sigmas should have different outputs" + assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, ( + "ays timesteps and ays sigmas should have the same outputs" + ) + assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, ( + "use ays timesteps should have different outputs" + ) + assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, ( + "use ays sigmas should have different outputs" + ) def test_ip_adapter(self): expected_pipe_slice = None @@ -742,9 +742,9 @@ class StableDiffusionXLPipelineFastTests( inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}} latents = pipe_1(**inputs_1).images[0] - assert ( - expected_steps_1 == done_steps - ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}" + assert expected_steps_1 == done_steps, ( + f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}" + ) with self.assertRaises(ValueError) as cm: inputs_2 = { @@ -771,9 +771,9 @@ class StableDiffusionXLPipelineFastTests( pipe_3(**inputs_3).images[0] assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :] - assert ( - expected_steps == done_steps - ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}" + assert expected_steps == done_steps, ( + f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}" + ) for steps in [7, 11, 20]: for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]): diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index 66ae581a05..729c6981d2 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -585,9 +585,9 @@ class StableDiffusionXLInpaintPipelineFastTests( inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}} latents = pipe_1(**inputs_1).images[0] - assert ( - expected_steps_1 == done_steps - ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}" + assert expected_steps_1 == done_steps, ( + f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}" + ) inputs_2 = { **inputs, @@ -601,9 +601,9 @@ class StableDiffusionXLInpaintPipelineFastTests( pipe_3(**inputs_3).images[0] assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :] - assert ( - expected_steps == done_steps - ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}" + assert expected_steps == done_steps, ( + f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}" + ) for steps in [7, 11, 20]: for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]): diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index ae5a12e04b..00c7636ed9 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -167,9 +167,9 @@ class DownloadTests(unittest.TestCase): download_requests = [r.method for r in m.request_history] assert download_requests.count("HEAD") == 15, "15 calls to files" assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json" - assert ( - len(download_requests) == 32 - ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json" + assert len(download_requests) == 32, ( + "2 calls per file (15 files) + send_telemetry, model_info and model_index.json" + ) with requests_mock.mock(real_http=True) as m: DiffusionPipeline.download( @@ -179,9 +179,9 @@ class DownloadTests(unittest.TestCase): cache_requests = [r.method for r in m.request_history] assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD" assert cache_requests.count("GET") == 1, "model info is only GET" - assert ( - len(cache_requests) == 2 - ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" + assert len(cache_requests) == 2, ( + "We should call only `model_info` to check for _commit hash and `send_telemetry`" + ) def test_less_downloads_passed_object(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -217,9 +217,9 @@ class DownloadTests(unittest.TestCase): assert download_requests.count("HEAD") == 13, "13 calls to files" # 17 - 2 because no call to config or model file for `safety_checker` assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json" - assert ( - len(download_requests) == 28 - ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json" + assert len(download_requests) == 28, ( + "2 calls per file (13 files) + send_telemetry, model_info and model_index.json" + ) with requests_mock.mock(real_http=True) as m: DiffusionPipeline.download( @@ -229,9 +229,9 @@ class DownloadTests(unittest.TestCase): cache_requests = [r.method for r in m.request_history] assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD" assert cache_requests.count("GET") == 1, "model info is only GET" - assert ( - len(cache_requests) == 2 - ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" + assert len(cache_requests) == 2, ( + "We should call only `model_info` to check for _commit hash and `send_telemetry`" + ) def test_download_only_pytorch(self): with tempfile.TemporaryDirectory() as tmpdirname: diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index b69669464d..be5245796b 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -191,12 +191,12 @@ class SDFunctionTesterMixin: inputs["output_type"] = "np" output_no_freeu = pipe(**inputs)[0] - assert not np.allclose( - output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1] - ), "Enabling of FreeU should lead to different results." - assert np.allclose( - output, output_no_freeu, atol=1e-2 - ), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}." + assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), ( + "Enabling of FreeU should lead to different results." + ) + assert np.allclose(output, output_no_freeu, atol=1e-2), ( + f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}." + ) def test_fused_qkv_projections(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -217,12 +217,12 @@ class SDFunctionTesterMixin: and hasattr(component, "original_attn_processors") and component.original_attn_processors is not None ): - assert check_qkv_fusion_processors_exist( - component - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." - assert check_qkv_fusion_matches_attn_procs_length( - component, component.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." + assert check_qkv_fusion_processors_exist(component), ( + "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + ) + assert check_qkv_fusion_matches_attn_procs_length(component, component.original_attn_processors), ( + "Something wrong with the attention processors concerning the fused QKV projections." + ) inputs = self.get_dummy_inputs(device) inputs["return_dict"] = False @@ -235,15 +235,15 @@ class SDFunctionTesterMixin: image_disabled = pipe(**inputs)[0] image_slice_disabled = image_disabled[0, -3:, -3:, -1] - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), ( + "Fusion of QKV projections shouldn't affect the outputs." + ) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + ) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( + "Original outputs should match when fused QKV projections are disabled." + ) class IPAdapterTesterMixin: @@ -909,9 +909,9 @@ class PipelineFromPipeTesterMixin: for component in pipe_original.components.values(): if hasattr(component, "attn_processors"): - assert all( - type(proc) == AttnProcessor for proc in component.attn_processors.values() - ), "`from_pipe` changed the attention processor in original pipeline." + assert all(type(proc) == AttnProcessor for proc in component.attn_processors.values()), ( + "`from_pipe` changed the attention processor in original pipeline." + ) @require_accelerator @require_accelerate_version_greater("0.14.0") @@ -2569,12 +2569,12 @@ class PyramidAttentionBroadcastTesterMixin: image_slice_pab_disabled = output.flatten() image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:])) - assert np.allclose( - original_image_slice, image_slice_pab_enabled, atol=expected_atol - ), "PAB outputs should not differ much in specified timestep range." - assert np.allclose( - original_image_slice, image_slice_pab_disabled, atol=1e-4 - ), "Outputs from normal inference and after disabling cache should not differ." + assert np.allclose(original_image_slice, image_slice_pab_enabled, atol=expected_atol), ( + "PAB outputs should not differ much in specified timestep range." + ) + assert np.allclose(original_image_slice, image_slice_pab_disabled, atol=1e-4), ( + "Outputs from normal inference and after disabling cache should not differ." + ) class FasterCacheTesterMixin: @@ -2639,12 +2639,12 @@ class FasterCacheTesterMixin: output = run_forward(pipe).flatten() image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:])) - assert np.allclose( - original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol - ), "FasterCache outputs should not differ much in specified timestep range." - assert np.allclose( - original_image_slice, image_slice_faster_cache_disabled, atol=1e-4 - ), "Outputs from normal inference and after disabling cache should not differ." + assert np.allclose(original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol), ( + "FasterCache outputs should not differ much in specified timestep range." + ) + assert np.allclose(original_image_slice, image_slice_faster_cache_disabled, atol=1e-4), ( + "Outputs from normal inference and after disabling cache should not differ." + ) def test_faster_cache_state(self): from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index 084d62a8c6..fa544c91f2 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -191,12 +191,12 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898]) - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" - assert ( - np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + ) + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, ( + f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + ) @require_torch_accelerator def test_offloads(self): diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 55b3202ad0..28c354709d 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -357,9 +357,9 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): prediction_type=prediction_type, final_sigmas_type=final_sigmas_type, ) - assert ( - torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 - ), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}" + assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, ( + f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}" + ) def test_beta_sigmas(self): self.check_over_configs(use_beta_sigmas=True) diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 7cbaa5cc5e..0756a5ed71 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -345,9 +345,9 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest): lower_order_final=lower_order_final, final_sigmas_type=final_sigmas_type, ) - assert ( - torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 - ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}" + assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, ( + f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}" + ) def test_beta_sigmas(self): self.check_over_configs(use_beta_sigmas=True) diff --git a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py index e97d64ec5f..8525ce61c4 100644 --- a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py +++ b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py @@ -188,9 +188,9 @@ class EDMDPMSolverMultistepSchedulerTest(SchedulerCommonTest): prediction_type=prediction_type, algorithm_type=algorithm_type, ) - assert ( - not torch.isnan(sample).any() - ), f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}" + assert not torch.isnan(sample).any(), ( + f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}" + ) def test_lower_order_final(self): self.check_over_configs(lower_order_final=True) diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index 4c7e02442c..01e173a631 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -245,9 +245,9 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): interpolation_type=interpolation_type, final_sigmas_type=final_sigmas_type, ) - assert ( - torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 - ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}" + assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, ( + f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}" + ) def test_custom_sigmas(self): for prediction_type in ["epsilon", "sample", "v_prediction"]: @@ -260,9 +260,9 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): prediction_type=prediction_type, final_sigmas_type=final_sigmas_type, ) - assert ( - torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 - ), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}" + assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, ( + f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}" + ) def test_beta_sigmas(self): self.check_over_configs(use_beta_sigmas=True) diff --git a/tests/schedulers/test_scheduler_heun.py b/tests/schedulers/test_scheduler_heun.py index 9e060c6d47..90012f5525 100644 --- a/tests/schedulers/test_scheduler_heun.py +++ b/tests/schedulers/test_scheduler_heun.py @@ -216,9 +216,9 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest): prediction_type=prediction_type, timestep_spacing=timestep_spacing, ) - assert ( - torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 - ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}" + assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, ( + f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}" + ) def test_beta_sigmas(self): self.check_over_configs(use_beta_sigmas=True) diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py index 4e7bc0af68..4e1713c9ce 100644 --- a/tests/single_file/single_file_testing_utils.py +++ b/tests/single_file/single_file_testing_utils.py @@ -72,9 +72,9 @@ class SDSingleFileTesterMixin: continue assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline" - assert isinstance( - component, pipe.components[component_name].__class__ - ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same" + assert isinstance(component, pipe.components[component_name].__class__), ( + f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same" + ) for param_name, param_value in component.config.items(): if param_name in PARAMS_TO_IGNORE: @@ -85,9 +85,9 @@ class SDSingleFileTesterMixin: if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None: pipe.components[component_name].config[param_name] = param_value - assert ( - pipe.components[component_name].config[param_name] == param_value - ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}" + assert pipe.components[component_name].config[param_name] == param_value, ( + f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}" + ) def test_single_file_components(self, pipe=None, single_file_pipe=None): single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( @@ -253,9 +253,9 @@ class SDXLSingleFileTesterMixin: continue assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline" - assert isinstance( - component, pipe.components[component_name].__class__ - ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same" + assert isinstance(component, pipe.components[component_name].__class__), ( + f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same" + ) for param_name, param_value in component.config.items(): if param_name in PARAMS_TO_IGNORE: @@ -266,9 +266,9 @@ class SDXLSingleFileTesterMixin: if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None: pipe.components[component_name].config[param_name] = param_value - assert ( - pipe.components[component_name].config[param_name] == param_value - ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}" + assert pipe.components[component_name].config[param_name] == param_value, ( + f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}" + ) def test_single_file_components(self, pipe=None, single_file_pipe=None): single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_lumina2_transformer.py b/tests/single_file/test_lumina2_transformer.py index 78e68c4c2d..d3ffd4fc3a 100644 --- a/tests/single_file/test_lumina2_transformer.py +++ b/tests/single_file/test_lumina2_transformer.py @@ -60,9 +60,9 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) def test_checkpoint_loading(self): for ckpt_path in self.alternate_keys_ckpt_paths: diff --git a/tests/single_file/test_model_autoencoder_dc_single_file.py b/tests/single_file/test_model_autoencoder_dc_single_file.py index b1faeb7877..31b2eb6e36 100644 --- a/tests/single_file/test_model_autoencoder_dc_single_file.py +++ b/tests/single_file/test_model_autoencoder_dc_single_file.py @@ -87,9 +87,9 @@ class AutoencoderDCSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between pretrained loading and single file loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading" + ) def test_single_file_in_type_variant_components(self): # `in` variant checkpoints require passing in a `config` parameter @@ -106,9 +106,9 @@ class AutoencoderDCSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between pretrained loading and single file loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading" + ) def test_single_file_mix_type_variant_components(self): repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers" @@ -121,6 +121,6 @@ class AutoencoderDCSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between pretrained loading and single file loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading" + ) diff --git a/tests/single_file/test_model_controlnet_single_file.py b/tests/single_file/test_model_controlnet_single_file.py index bfcb802380..3580d73531 100644 --- a/tests/single_file/test_model_controlnet_single_file.py +++ b/tests/single_file/test_model_controlnet_single_file.py @@ -58,9 +58,9 @@ class ControlNetModelSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) def test_single_file_arguments(self): model_default = self.model_class.from_single_file(self.ckpt_path) diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py index 0ec97db26a..bf11faaa9c 100644 --- a/tests/single_file/test_model_flux_transformer_single_file.py +++ b/tests/single_file/test_model_flux_transformer_single_file.py @@ -58,9 +58,9 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) def test_checkpoint_loading(self): for ckpt_path in self.alternate_keys_ckpt_paths: diff --git a/tests/single_file/test_model_motion_adapter_single_file.py b/tests/single_file/test_model_motion_adapter_single_file.py index b195f25d09..a747f16dc1 100644 --- a/tests/single_file/test_model_motion_adapter_single_file.py +++ b/tests/single_file/test_model_motion_adapter_single_file.py @@ -40,9 +40,9 @@ class MotionAdapterSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between pretrained loading and single file loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading" + ) def test_single_file_components_version_v1_5_2(self): ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt" @@ -55,9 +55,9 @@ class MotionAdapterSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between pretrained loading and single file loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading" + ) def test_single_file_components_version_v1_5_3(self): ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt" @@ -70,9 +70,9 @@ class MotionAdapterSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between pretrained loading and single file loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading" + ) def test_single_file_components_version_sdxl_beta(self): ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt" @@ -85,6 +85,6 @@ class MotionAdapterSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between pretrained loading and single file loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading" + ) diff --git a/tests/single_file/test_model_sd_cascade_unet_single_file.py b/tests/single_file/test_model_sd_cascade_unet_single_file.py index 08b04e3cd7..92b371c3fb 100644 --- a/tests/single_file/test_model_sd_cascade_unet_single_file.py +++ b/tests/single_file/test_model_sd_cascade_unet_single_file.py @@ -60,9 +60,9 @@ class StableCascadeUNetSingleFileTest(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) def test_single_file_components_stage_b_lite(self): model_single_file = StableCascadeUNet.from_single_file( @@ -77,9 +77,9 @@ class StableCascadeUNetSingleFileTest(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) def test_single_file_components_stage_c(self): model_single_file = StableCascadeUNet.from_single_file( @@ -94,9 +94,9 @@ class StableCascadeUNetSingleFileTest(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) def test_single_file_components_stage_c_lite(self): model_single_file = StableCascadeUNet.from_single_file( @@ -111,6 +111,6 @@ class StableCascadeUNetSingleFileTest(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py index 9db4cddb3c..bba1726ae3 100644 --- a/tests/single_file/test_model_vae_single_file.py +++ b/tests/single_file/test_model_vae_single_file.py @@ -91,9 +91,9 @@ class AutoencoderKLSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between pretrained loading and single file loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading" + ) def test_single_file_arguments(self): model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id) diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py index f5720ddd39..7f0e1c1a4b 100644 --- a/tests/single_file/test_model_wan_autoencoder_single_file.py +++ b/tests/single_file/test_model_wan_autoencoder_single_file.py @@ -56,6 +56,6 @@ class AutoencoderKLWanSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py index 9b938aa175..36f0919cac 100644 --- a/tests/single_file/test_model_wan_transformer3d_single_file.py +++ b/tests/single_file/test_model_wan_transformer3d_single_file.py @@ -57,9 +57,9 @@ class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) @require_big_gpu_with_torch_cuda @@ -88,6 +88,6 @@ class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py index 7695e15777..802ca37abf 100644 --- a/tests/single_file/test_sana_transformer.py +++ b/tests/single_file/test_sana_transformer.py @@ -47,9 +47,9 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" + assert model.config[param_name] == param_value, ( + f"{param_name} differs between single file loading and pretrained loading" + ) def test_checkpoint_loading(self): for ckpt_path in self.alternate_keys_ckpt_paths: diff --git a/utils/log_reports.py b/utils/log_reports.py index dd1b258519..5575c9ba84 100644 --- a/utils/log_reports.py +++ b/utils/log_reports.py @@ -35,7 +35,7 @@ def main(slack_channel_name=None): if line.get("nodeid", "") != "": test = line["nodeid"] if line.get("duration", None) is not None: - duration = f'{line["duration"]:.4f}' + duration = f"{line['duration']:.4f}" if line.get("outcome", "") == "failed": section_num_failed += 1 failed.append([test, duration, log.name.split("_")[0]]) diff --git a/utils/update_metadata.py b/utils/update_metadata.py index a97e65801c..4fde581d41 100644 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -104,8 +104,7 @@ def update_metadata(commit_sha: str): if commit_sha is not None: commit_message = ( - f"Update with commit {commit_sha}\n\nSee: " - f"https://github.com/huggingface/diffusers/commit/{commit_sha}" + f"Update with commit {commit_sha}\n\nSee: https://github.com/huggingface/diffusers/commit/{commit_sha}" ) else: commit_message = "Update"