Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[advanced dreambooth lora] add clip_skip arg (#8715)
* add clip_skip
* style
* smol fix

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -573,6 +573,13 @@ def parse_args(input_args=None):
         default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--clip_skip",
+        type=int,
+        default=None,
+        help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
+        "the output of the pre-final layer will be used for computing the prompt embeddings.",
+    )
     parser.add_argument(
         "--text_encoder_lr",
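
For reference, a self-contained sketch (not part of the commit) of how the new flag behaves once parsed; the parser below only mirrors the argument added above, and the example values are illustrative:

import argparse

# Minimal parser that mirrors the --clip_skip argument added above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--clip_skip",
    type=int,
    default=None,
    help="Number of layers to be skipped from CLIP while computing the prompt embeddings.",
)

print(parser.parse_args([]).clip_skip)                    # None -> keep the existing penultimate-layer behavior
print(parser.parse_args(["--clip_skip", "2"]).clip_skip)  # 2    -> skip two additional CLIP layers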
@@ -1236,7 +1243,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):


 # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
-def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
     prompt_embeds_list = []

     for i, text_encoder in enumerate(text_encoders):
@@ -1253,7 +1260,11 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):

         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds[-1][-2]
+        if clip_skip is None:
+            prompt_embeds = prompt_embeds[-1][-2]
+        else:
+            # "2" because SDXL always indexes from the penultimate layer.
+            prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
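
A minimal sketch of the indexing rule above; the list below stands in for the hidden_states tuple returned by text_encoder(..., output_hidden_states=True), and the 12-layer depth is an assumed example, not taken from the script:

# Stand-in for text_encoder(...)[-1]: the embeddings output followed by one entry per transformer layer.
hidden_states = ["embeddings"] + [f"layer_{i}" for i in range(1, 13)]  # hypothetical 12-layer CLIP text encoder

def select_layer(hidden_states, clip_skip=None):
    # Mirrors the branch introduced in the hunk above.
    if clip_skip is None:
        return hidden_states[-2]            # penultimate layer, the existing SDXL default
    return hidden_states[-(clip_skip + 2)]  # skip clip_skip additional layers below that

print(select_layer(hidden_states))               # layer_11
print(select_layer(hidden_states, clip_skip=1))  # layer_10
print(select_layer(hidden_states, clip_skip=2))  # layer_9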
@@ -1830,9 +1841,9 @@ def main(args):
     tokenizers = [tokenizer_one, tokenizer_two]
     text_encoders = [text_encoder_one, text_encoder_two]

-    def compute_text_embeddings(prompt, text_encoders, tokenizers):
+    def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
         with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip)
             prompt_embeds = prompt_embeds.to(accelerator.device)
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
             return prompt_embeds, pooled_prompt_embeds
@@ -1842,7 +1853,7 @@ def main(args):
     # the redundant encoding.
     if freeze_text_encoder and not train_dataset.custom_instance_prompts:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
-            args.instance_prompt, text_encoders, tokenizers
+            args.instance_prompt, text_encoders, tokenizers, args.clip_skip
         )

     # Handle class prompt for prior-preservation.
@@ -2052,7 +2063,7 @@ def main(args):
                 if train_dataset.custom_instance_prompts:
                     if freeze_text_encoder:
                         prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
-                            prompts, text_encoders, tokenizers
+                            prompts, text_encoders, tokenizers, args.clip_skip
                         )

                     else:
@@ -2147,6 +2158,7 @@ def main(args):
                             tokenizers=None,
                             prompt=None,
                             text_input_ids_list=[tokens_one, tokens_two],
+                            clip_skip=args.clip_skip,
                         )
                         unet_added_conditions.update(
                             {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
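
Taken together, for each text encoder the script now selects the prompt-embedding layer according to --clip_skip. A hedged end-to-end sketch of that selection with a single CLIP text encoder from transformers; the model ID, prompt, and clip_skip value are illustrative choices, not taken from the commit:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Placeholder model: the first SDXL text encoder is a CLIP ViT-L/14 text model of this shape.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer("a photo of sks dog", padding="max_length", return_tensors="pt")
with torch.no_grad():
    # return_dict=False plus output_hidden_states=True: the last tuple element holds the per-layer hidden states.
    outputs = text_encoder(tokens.input_ids, output_hidden_states=True, return_dict=False)

hidden_states = outputs[-1]
clip_skip = 2  # example value for --clip_skip
prompt_embeds = hidden_states[-2] if clip_skip is None else hidden_states[-(clip_skip + 2)]
print(prompt_embeds.shape)  # torch.Size([1, 77, 768]) for this encoder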