1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Make dreambooth lora more robust to orig unet (#3462)

* Make dreambooth lora more robust to orig unet

* up
This commit is contained in:
Patrick von Platen
2023-05-17 12:20:13 +02:00
committed by GitHub
parent 15f1bab13b
commit 3ebd2d1f9e

View File

@@ -31,7 +31,7 @@ import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, model_info, upload_folder
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
@@ -589,16 +589,6 @@ class PromptDataset(Dataset):
return example
def model_has_vae(args):
config_file_name = os.path.join("vae", AutoencoderKL.config_name)
if os.path.isdir(args.pretrained_model_name_or_path):
config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
return os.path.isfile(config_file_name)
else:
files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
return any(file.rfilename == config_file_name for file in files_in_repo)
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
if tokenizer_max_length is not None:
max_length = tokenizer_max_length
@@ -753,11 +743,13 @@ def main(args):
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
if model_has_vae(args):
try:
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
)
else:
except OSError:
# IF does not have a VAE so let's just set it to None
# We don't have to error out here
vae = None
unet = UNet2DConditionModel.from_pretrained(