mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Make dreambooth lora more robust to orig unet (#3462)
* Make dreambooth lora more robust to orig unet * up
This commit is contained in:
committed by
GitHub
parent
15f1bab13b
commit
3ebd2d1f9e
@@ -31,7 +31,7 @@ import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, model_info, upload_folder
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
@@ -589,16 +589,6 @@ class PromptDataset(Dataset):
|
||||
return example
|
||||
|
||||
|
||||
def model_has_vae(args):
|
||||
config_file_name = os.path.join("vae", AutoencoderKL.config_name)
|
||||
if os.path.isdir(args.pretrained_model_name_or_path):
|
||||
config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
|
||||
return os.path.isfile(config_file_name)
|
||||
else:
|
||||
files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
|
||||
return any(file.rfilename == config_file_name for file in files_in_repo)
|
||||
|
||||
|
||||
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
|
||||
if tokenizer_max_length is not None:
|
||||
max_length = tokenizer_max_length
|
||||
@@ -753,11 +743,13 @@ def main(args):
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
if model_has_vae(args):
|
||||
try:
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
|
||||
)
|
||||
else:
|
||||
except OSError:
|
||||
# IF does not have a VAE so let's just set it to None
|
||||
# We don't have to error out here
|
||||
vae = None
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user