mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Trivial fix for undefined symbol in train_dreambooth.py (#1598)
easy fix for undefined name in train_dreambooth.py import_model_class_from_model_name_or_path loads a pretrained model and refers to args.revision in a context where args is undefined. I modified the function to take revision as an argument and modified the invocation of the function to pass in the revision from args. Seems like this was caused by a cut and paste.
This commit is contained in:
@@ -30,11 +30,11 @@ check_min_version("0.10.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
revision=revision,
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
@@ -469,7 +469,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# import correct text encoder class
|
||||
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
|
||||
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user