From 326de4191578dfb55cb968880d40d703075e331e Mon Sep 17 00:00:00 2001 From: Ben Sherman Date: Wed, 7 Dec 2022 12:39:48 -0800 Subject: [PATCH] Trivial fix for undefined symbol in train_dreambooth.py (#1598) easy fix for undefined name in train_dreambooth.py import_model_class_from_model_name_or_path loads a pretrained model and refers to args.revision in a context where args is undefined. I modified the function to take revision as an argument and modified the invocation of the function to pass in the revision from args. Seems like this was caused by a cut and paste. --- examples/dreambooth/train_dreambooth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 4424ddc0c5..b904920f1c 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -30,11 +30,11 @@ check_min_version("0.10.0.dev0") logger = get_logger(__name__) -def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str): +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", - revision=args.revision, + revision=revision, ) model_class = text_encoder_config.architectures[0] @@ -469,7 +469,7 @@ def main(args): ) # import correct text encoder class - text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path) + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) # Load models and create wrapper for stable diffusion text_encoder = text_encoder_cls.from_pretrained(