[HiDream LoRA] optimizations + small updates (#11381)
* 1. add pre-computation of prompt embeddings when custom prompts are used as well; 2. save model card even if the model is not pushed to the hub; 3. remove scheduler initialization from the code example - not necessary anymore (it's now in the base model's config); 4. add `skip_final_inference` - allows running with validation but skipping the final loading of the pipeline with the lora weights to reduce memory requirements
* pre encode validation prompt as well
* Update examples/dreambooth/train_dreambooth_lora_hidream.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update examples/dreambooth/train_dreambooth_lora_hidream.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update examples/dreambooth/train_dreambooth_lora_hidream.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* pre encode validation prompt as well
* Apply style fixes
* empty commit
* change default trained modules
* empty commit
* address comments + change encoding of the validation prompt (before, it was only pre-encoded if custom prompts were provided, but it should be pre-encoded either way)
* Apply style fixes
* empty commit
* fix validation_embeddings definition
* fix final inference condition
* fix pipeline deletion in last inference
* Apply style fixes
* empty commit
* layers
* remove readme remarks on only pre-computing when instance prompt is provided and change example to 3d icons
* smol fix
* empty commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
@@ -51,54 +51,41 @@ When running `accelerate config`, if we specify torch compile mode to True there
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
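If you want to double-check the environment, a quick sanity check along these lines (illustrative only, not part of the training script) confirms the installed `peft` version:

```python
# Illustrative sanity check for the peft>=0.14.0 requirement mentioned above.
# Assumes `peft` and `packaging` are installed in the current environment.
from packaging import version

import peft

assert version.parse(peft.__version__) >= version.parse("0.14.0"), (
    f"Found peft {peft.__version__}; please upgrade with `pip install -U peft`."
)
```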
### Dog toy example

### 3d icon example

Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
For this example we will use some 3d icon images: https://huggingface.co/datasets/linoyts/3d_icon.
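If you'd like a quick look at the data before training, a small optional sketch like the following can be used; the `datasets` library and the `"image"`/`"prompt"` column names are assumptions based on the `--caption_column="prompt"` flag used in the launch command below:

```python
# Optional peek at the training data. The "image" and "prompt" column names are
# assumptions inferred from the --caption_column="prompt" flag in the launch command.
from datasets import load_dataset

ds = load_dataset("linoyts/3d_icon", split="train")
print(ds)               # row count and column names
print(ds[0]["prompt"])  # an example caption
ds[0]["image"].save("sample_icon.png")  # images load as PIL objects
```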
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.

Now, we can launch training using:
> [!NOTE]
> The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
> 8-bit Adam optimizer, latent caching, offloading, no validation.
> Additionally, when provided with 'instance_prompt' only and no 'caption_column' (used for custom prompts for each image)
> text embeddings are pre-computed to save memory.
> 8-bit Adam optimizer, latent caching, offloading, no validation.
> all text embeddings are pre-computed to save memory.
```bash
export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
export INSTANCE_DIR="dog"
export INSTANCE_DIR="linoyts/3d_icon"
export OUTPUT_DIR="trained-hidream-lora"

accelerate launch train_dreambooth_lora_hidream.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--dataset_name=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--instance_prompt="3d icon" \
--caption_column="prompt"\
--validation_prompt="a 3dicon, a llama eating ramen" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--use_8bit_adam \
--rank=16 \
--rank=8 \
--learning_rate=2e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=1000 \
--cache_latents \
--cache_latents\
--gradient_checkpointing \
--validation_epochs=25 \
--seed="0" \
@@ -128,6 +115,5 @@ We provide several options for optimizing memory optimization:
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
* `--instance_prompt` and no `--caption_column`: when only an instance prompt is provided, we will pre-compute the text embeddings and remove the text encoders from memory once done.
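To make the `--cache_latents` and `--offload` behavior concrete, here is a rough, simplified sketch of what the script does internally; names like `vae`, `train_dataloader`, and `accelerator` refer to objects the training script already has, so this is not a drop-in snippet:

```python
# Simplified sketch of latent caching: encode every training image once with the VAE,
# keep the resulting latent distributions, then move the VAE off the GPU and delete it.
import torch

latents_cache = []
for batch in train_dataloader:  # existing DataLoader yielding {"pixel_values": ...}
    with torch.no_grad():
        pixel_values = batch["pixel_values"].to(accelerator.device, dtype=vae.dtype)
        latents_cache.append(vae.encode(pixel_values).latent_dist)

vae = vae.to("cpu")  # move to CPU first so GPU memory is actually released
del vae
torch.cuda.empty_cache()
```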
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
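Once training finishes, inference with the saved LoRA can look roughly like the model card example further below. A minimal sketch, assuming the LoRA was saved to `trained-hidream-lora` (the `--output_dir` from the launch command) and that the Llama checkpoint mirrors the tokenizer used in that example:

```python
# Minimal post-training inference sketch. The meta-llama checkpoint supplies the fourth
# tokenizer/text encoder, as in the model card example; adjust paths to your setup.
import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

from diffusers import HiDreamImagePipeline

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
tokenizer_4.pad_token = tokenizer_4.eos_token
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("trained-hidream-lora")  # the --output_dir used during training

image = pipe(prompt="a 3dicon, a llama eating ramen").images[0]
image.save("llama_icon.png")
```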
@@ -120,11 +120,7 @@ You should use `{instance_prompt}` to trigger the image generation.
```py
>>> import torch
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline

>>> scheduler = UniPCMultistepScheduler(
... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
... )
>>> from diffusers import HiDreamImagePipeline

>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
@@ -136,7 +132,6 @@ You should use `{instance_prompt}` to trigger the image generation.

>>> pipe = HiDreamImagePipeline.from_pretrained(
... "HiDream-ai/HiDream-I1-Full",
... scheduler=scheduler,
... tokenizer_4=tokenizer_4,
... text_encoder_4=text_encoder_4,
... torch_dtype=torch.bfloat16,
@@ -201,6 +196,7 @@ def log_validation(
torch_dtype,
is_final_validation=False,
):
args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
@@ -212,28 +208,16 @@ def log_validation(
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
(
prompt_embeds_t5,
negative_prompt_embeds_t5,
prompt_embeds_llama3,
negative_prompt_embeds_llama3,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(
pipeline_args["prompt"],
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds_t5=prompt_embeds_t5,
prompt_embeds_llama3=prompt_embeds_llama3,
negative_prompt_embeds_t5=negative_prompt_embeds_t5,
negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
prompt_embeds_t5=pipeline_args["prompt_embeds_t5"],
prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"],
negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"],
negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"],
pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"],
negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
@@ -252,9 +236,9 @@ def log_validation(
}
)

pipeline.to("cpu")
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()

return images
@@ -392,6 +376,14 @@ def parse_args(input_args=None):
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)

parser.add_argument(
"--skip_final_inference",
default=False,
action="store_true",
help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.",
)

parser.add_argument(
"--final_validation_prompt",
type=str,
@@ -1016,6 +1008,7 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)

pipeline.to("cpu")
del pipeline
free_memory()
@@ -1140,7 +1133,7 @@ def main(args):
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
target_modules = ["to_k", "to_q", "to_v", "to_out"]

# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
@@ -1314,42 +1307,65 @@ def main(args):
)

def compute_text_embeddings(prompt, text_encoding_pipeline):
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
with torch.no_grad():
t5_prompt_embeds, _, llama3_prompt_embeds, _, pooled_prompt_embeds, _ = (
text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length)
)
if args.offload: # back to cpu
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds
(
t5_prompt_embeds,
negative_prompt_embeds_t5,
llama3_prompt_embeds,
negative_prompt_embeds_llama3,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length)
return (
t5_prompt_embeds,
llama3_prompt_embeds,
pooled_prompt_embeds,
negative_prompt_embeds_t5,
negative_prompt_embeds_llama3,
negative_pooled_prompt_embeds,
)

# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not train_dataset.custom_instance_prompts:
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
instance_prompt_hidden_states_t5,
instance_prompt_hidden_states_llama3,
instance_pooled_prompt_embeds,
_,
_,
_,
) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")

# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
(
class_prompt_hidden_states_t5,
class_prompt_hidden_states_llama3,
class_pooled_prompt_embeds,
) = compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")

# Clear the memory here
if not train_dataset.custom_instance_prompts:
# delete tokenizers and text ecnoders except for llama (tokenizer & te four)
# as it's needed for inference with pipeline
del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two, tokenizer_three
if not args.validation_prompt:
del tokenizer_four, text_encoder_four
free_memory()
validation_embeddings = {}
if args.validation_prompt is not None:
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
validation_embeddings["prompt_embeds_t5"],
validation_embeddings["prompt_embeds_llama3"],
validation_embeddings["pooled_prompt_embeds"],
validation_embeddings["negative_prompt_embeds_t5"],
validation_embeddings["negative_prompt_embeds_llama3"],
validation_embeddings["negative_pooled_prompt_embeds"],
) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1367,20 +1383,52 @@ def main(args):

vae_config_scaling_factor = vae.config.scaling_factor
vae_config_shift_factor = vae.config.shift_factor
if args.cache_latents:

# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
if precompute_latents:
t5_prompt_cache = []
llama3_prompt_cache = []
pooled_prompt_cache = []
latents_cache = []
if args.offload:
vae = vae.to(accelerator.device)
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if args.cache_latents:
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if train_dataset.custom_instance_prompts:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds, _, _, _ = compute_text_embeddings(
batch["prompts"], text_encoding_pipeline
)
t5_prompt_cache.append(t5_prompt_embeds)
llama3_prompt_cache.append(llama3_prompt_embeds)
pooled_prompt_cache.append(pooled_prompt_embeds)

if args.validation_prompt is None:
# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
if args.offload or args.cache_latents:
vae = vae.to("cpu")
if args.cache_latents:
del vae
free_memory()
# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
del (
text_encoder_one,
text_encoder_two,
text_encoder_three,
text_encoder_four,
tokenizer_two,
tokenizer_three,
tokenizer_four,
text_encoding_pipeline,
)
free_memory()

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -1487,9 +1535,9 @@ def main(args):
with accelerator.accumulate(models_to_accumulate):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
prompts, text_encoding_pipeline
)
t5_prompt_embeds = t5_prompt_cache[step]
llama3_prompt_embeds = llama3_prompt_cache[step]
pooled_prompt_embeds = pooled_prompt_cache[step]
else:
t5_prompt_embeds = t5_prompt_embeds.repeat(len(prompts), 1, 1)
llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, len(prompts), 1, 1)
@@ -1619,26 +1667,30 @@ def main(args):
# create pipeline
pipeline = HiDreamImagePipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer_4=tokenizer_four,
text_encoder_4=text_encoder_four,
tokenizer=None,
text_encoder=None,
tokenizer_2=None,
text_encoder_2=None,
tokenizer_3=None,
text_encoder_3=None,
tokenizer_4=None,
text_encoder_4=None,
transformer=accelerator.unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
pipeline_args=validation_embeddings,
torch_dtype=weight_dtype,
epoch=epoch,
)
free_memory()

images = None
del pipeline
images = None
free_memory()

# Save the lora layers
accelerator.wait_for_everyone()
@@ -1655,50 +1707,49 @@ def main(args):
transformer_lora_layers=transformer_lora_layers,
)

# Final inference
# Load previous pipeline
tokenizer_4 = AutoTokenizer.from_pretrained(args.pretrained_tokenizer_4_name_or_path)
tokenizer_4.pad_token = tokenizer_4.eos_token
text_encoder_4 = LlamaForCausalLM.from_pretrained(
args.pretrained_text_encoder_4_name_or_path,
output_hidden_states=True,
output_attentions=True,
torch_dtype=torch.bfloat16,
)
pipeline = HiDreamImagePipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer_4=tokenizer_4,
text_encoder_4=text_encoder_4,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)

# run inference
images = []
if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt):
prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
pipeline_args = {"prompt": prompt_to_use, "num_images_per_prompt": args.num_validation_images}
run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)
should_run_final_inference = not args.skip_final_inference and run_validation
if should_run_final_inference:
# Final inference
# Load previous pipeline
pipeline = HiDreamImagePipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=None,
text_encoder=None,
tokenizer_2=None,
text_encoder_2=None,
tokenizer_3=None,
text_encoder_3=None,
tokenizer_4=None,
text_encoder_4=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)

# run inference
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
pipeline_args=validation_embeddings,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
del pipeline
free_memory()

validation_prpmpt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
save_model_card(
(args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
instance_prompt=args.instance_prompt,
validation_prompt=validation_prpmpt,
validation_prompt=validation_prompt,
repo_folder=args.output_dir,
)

@@ -1711,7 +1762,6 @@ def main(args):
)

images = None
del pipeline

accelerator.end_training()