From 1221b28eac1801dd759e8d1df9fc9a2a998b41ed Mon Sep 17 00:00:00 2001 From: Alphin Jain <72972178+jainalphin@users.noreply.github.com> Date: Thu, 16 May 2024 15:49:54 +0530 Subject: [PATCH] Fix AttributeError in train_lcm_distill_lora_sdxl_wds.py (#7923) Fix conditional teacher model check in train_lcm_distill_lora_sdxl_wds.py Co-authored-by: Sayak Paul --- .../consistency_distillation/train_lcm_distill_lora_sdxl_wds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 08d6b23d6d..ce3e7f6248 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -1358,7 +1358,7 @@ def main(args): # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + if torch.backends.mps.is_available() or "playground" in args.pretrained_teacher_model: autocast_ctx = nullcontext() else: autocast_ctx = torch.autocast(accelerator.device.type)