From 56e772ab7ea797ff0fa220b644ebe3622167257d Mon Sep 17 00:00:00 2001
From: Lucain
Date: Sat, 20 Jul 2024 16:31:21 +0200
Subject: [PATCH] Use model_info.id instead of model_info.modelId (#8912)

Mention model_info.id instead of model_info.modelId
---
 scripts/generate_logits.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/scripts/generate_logits.py b/scripts/generate_logits.py
index 89dce0e78d..99d46d6628 100644
--- a/scripts/generate_logits.py
+++ b/scripts/generate_logits.py
@@ -103,12 +103,12 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([
 
 models = api.list_models(filter="diffusers")
 for mod in models:
-    if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
-        local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
+    if "google" in mod.author or mod.id == "CompVis/ldm-celebahq-256":
+        local_checkpoint = "/home/patrick/google_checkpoints/" + mod.id.split("/")[-1]
 
-        print(f"Started running {mod.modelId}!!!")
+        print(f"Started running {mod.id}!!!")
 
-        if mod.modelId.startswith("CompVis"):
+        if mod.id.startswith("CompVis"):
             model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
         else:
             model = UNet2DModel.from_pretrained(local_checkpoint)
@@ -122,6 +122,6 @@ for mod in models:
         logits = model(noise, time_step).sample
 
         assert torch.allclose(
-            logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
+            logits[0, 0, 0, :30], results["_".join("_".join(mod.id.split("/")).split("-"))], atol=1e-3
         )
-        print(f"{mod.modelId} has passed successfully!!!")
+        print(f"{mod.id} has passed successfully!!!")
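
For reference, a minimal sketch of the attribute this patch switches to, assuming a recent huggingface_hub release in which list_models() yields ModelInfo objects whose repo identifier is exposed as .id (modelId being the older name for the same field):

    from huggingface_hub import HfApi

    api = HfApi()
    # list_models() yields ModelInfo objects; .id holds the "namespace/name" repo id
    for mod in api.list_models(filter="diffusers", limit=3):
        print(mod.id)  # e.g. "CompVis/ldm-celebahq-256"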