Mirror of https://github.com/huggingface/diffusers.git
[Flux LoRA] fix issues in flux lora scripts (#11111)
* remove custom scheduler
* update requirements.txt
* log_validation with mixed precision
* add intermediate embeddings saving when checkpointing is enabled
* remove comment
* fix validation
* add unwrap_model for accelerator, torch.no_grad context for validation, fix accelerator.accumulate call in advanced script
* revert unwrap_model change temporarily
* add .module to address distributed training bug + replace accelerator.unwrap_model with unwrap_model
* changes to align advanced script with canonical script
* make changes for distributed training + unify unwrap_model calls in advanced script
* add module.dtype fix to dreambooth script
* unify unwrap_model calls in dreambooth script
* fix condition in validation run
* mixed precision
* Update examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* smol style change
* change autocast
* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
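Most of the hunks below repeat one pattern: when `accelerate` prepares a model for multi-GPU training it wraps it in `DistributedDataParallel`, and attributes such as `dtype` or `config` then live on `model.module` rather than on the wrapper itself. The sketch below isolates that pattern; `get_dtype` and `TinyEncoder` are illustrative stand-ins, not code from the repository.

```python
import torch
import torch.nn as nn


def get_dtype(model: nn.Module) -> torch.dtype:
    # Under DistributedDataParallel, attributes of the original model are
    # reached through `.module`; plain (unwrapped) models expose them directly.
    if hasattr(model, "module"):
        return model.module.dtype
    return model.dtype


class TinyEncoder(nn.Module):
    # stand-in for a text encoder, only here so the sketch runs anywhere
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @property
    def dtype(self):
        return next(self.parameters()).dtype


encoder = TinyEncoder().to(torch.float16)
print(get_dtype(encoder))  # torch.float16, with or without a DDP wrapper
```

The scripts apply the same `hasattr(text_encoder, "module")` check inside the prompt-encoding helpers and route other attribute access through a local `unwrap_model` helper instead of calling `accelerator.unwrap_model` at every site.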
@@ -1,7 +1,8 @@
-accelerate>=0.16.0
+accelerate>=0.31.0
 torchvision
-transformers>=4.25.1
+transformers>=4.41.2
 ftfy
 tensorboard
 Jinja2
-peft==0.7.0
+peft>=0.11.1
+sentencepiece
@@ -24,7 +24,7 @@ import re
 import shutil
 from contextlib import nullcontext
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional
 
 import numpy as np
 import torch
@@ -228,10 +228,20 @@ def log_validation(
 
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    autocast_ctx = nullcontext()
+    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
 
-    with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+        )
+    images = []
+    for _ in range(args.num_validation_images):
+        with autocast_ctx:
+            image = pipeline(
+                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
+            ).images[0]
+        images.append(image)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
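The hunk above is the mixed-precision validation fix: prompt embeddings are computed once under `torch.no_grad()` because the T5 encoder does not support autocast, and only the denoising call runs inside `torch.autocast` during intermediate validation. A rough standalone sketch of the same flow with a stock pipeline; the checkpoint id, prompt, and image count are placeholders.

```python
from contextlib import nullcontext

import torch
from diffusers import FluxPipeline

# placeholder checkpoint and prompt; any Flux checkpoint with this pipeline layout works
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
prompt = "a photo of sks dog"

is_final_validation = False
autocast_ctx = torch.autocast("cuda") if not is_final_validation else nullcontext()

# encode once, outside autocast (T5 does not support autocast)
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(prompt, prompt_2=prompt)

images = []
for _ in range(2):  # args.num_validation_images in the script
    with autocast_ctx:
        image = pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds).images[0]
    images.append(image)
```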
@@ -657,6 +667,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
+
     parser.add_argument(
         "--lora_layers",
         type=str,
@@ -666,6 +677,7 @@ def parse_args(input_args=None):
             'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
         ),
     )
+
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -738,6 +750,15 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--upcast_before_saving",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "Defaults to precision dtype used for training to save memory"
+        ),
+    )
     parser.add_argument(
         "--prior_generation_precision",
         type=str,
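The new `--upcast_before_saving` option is a plain `store_true` flag. The snippet below re-creates just that flag with `argparse` to show its default and how passing it flips the value; it is an illustration, not part of the training script.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--upcast_before_saving",
    action="store_true",
    default=False,
    help="Whether to upcast the trained transformer layers to float32 before saving.",
)

print(parser.parse_args([]).upcast_before_saving)                          # False
print(parser.parse_args(["--upcast_before_saving"]).upcast_before_saving)  # True
```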
@@ -1147,7 +1168,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False):
     return text_input_ids
 
 
-def _get_t5_prompt_embeds(
+def _encode_prompt_with_t5(
     text_encoder,
     tokenizer,
     max_sequence_length=512,
@@ -1176,7 +1197,10 @@ def _get_t5_prompt_embeds(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -1188,7 +1212,7 @@ def _get_t5_prompt_embeds(
     return prompt_embeds
 
 
-def _get_clip_prompt_embeds(
+def _encode_prompt_with_clip(
     text_encoder,
     tokenizer,
     prompt: str,
@@ -1217,9 +1241,13 @@ def _get_clip_prompt_embeds(
 
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1238,136 +1266,35 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
-    dtype = text_encoders[0].dtype
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
 
-    pooled_prompt_embeds = _get_clip_prompt_embeds(
+    pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
         prompt=prompt,
         device=device if device is not None else text_encoders[0].device,
         num_images_per_prompt=num_images_per_prompt,
-        text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )
 
-    prompt_embeds = _get_t5_prompt_embeds(
+    prompt_embeds = _encode_prompt_with_t5(
         text_encoder=text_encoders[1],
         tokenizer=tokenizers[1],
         max_sequence_length=max_sequence_length,
         prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[1].device,
-        text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )
 
-    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-    text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
     return prompt_embeds, pooled_prompt_embeds, text_ids
 
 
-# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
-# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
-class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        with torch.no_grad():
-            # create weights for timesteps
-            num_timesteps = 1000
-
-            # generate the multiplier based on cosmap loss weighing
-            # this is only used on linear timesteps for now
-
-            # cosine map weighing is higher in the middle and lower at the ends
-            # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
-            # cosmap_weighing = 2 / (math.pi * bot)
-
-            # sigma sqrt weighing is significantly higher at the end and lower at the beginning
-            sigma_sqrt_weighing = (self.sigmas**-2.0).float()
-            # clip at 1e4 (1e6 is too high)
-            sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
-            # bring to a mean of 1
-            sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
-
-            # Create linear timesteps from 1000 to 0
-            timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
-
-            self.linear_timesteps = timesteps
-            # self.linear_timesteps_weights = cosmap_weighing
-            self.linear_timesteps_weights = sigma_sqrt_weighing
-
-            # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
-            pass
-
-    def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
-        # Get the indices of the timesteps
-        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
-
-        # Get the weights for the timesteps
-        weights = self.linear_timesteps_weights[step_indices].flatten()
-
-        return weights
-
-    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
-        sigmas = self.sigmas.to(device=device, dtype=dtype)
-        schedule_timesteps = self.timesteps.to(device)
-        timesteps = timesteps.to(device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < n_dim:
-            sigma = sigma.unsqueeze(-1)
-
-        return sigma
-
-    def add_noise(
-        self,
-        original_samples: torch.Tensor,
-        noise: torch.Tensor,
-        timesteps: torch.Tensor,
-    ) -> torch.Tensor:
-        ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
-        ## Add noise according to flow matching.
-        ## zt = (1 - texp) * x + texp * z1
-
-        # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-        # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
-        # timestep needs to be in [0, 1], we store them in [0, 1000]
-        # noisy_sample = (1 - timestep) * latent + timestep * noise
-        t_01 = (timesteps / 1000).to(original_samples.device)
-        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
-
-        # n_dim = original_samples.ndim
-        # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-        # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
-        return noisy_model_input
-
-    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
-        return sample
-
-    def set_train_timesteps(self, num_timesteps, device, linear=False):
-        if linear:
-            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
-            self.timesteps = timesteps
-            return timesteps
-        else:
-            # distribute them closer to center. Inference distributes them as a bias toward first
-            # Generate values from 0 to 1
-            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
-
-            # Scale and reverse the values to go from 1000 to 0
-            timesteps = (1 - t) * 1000
-
-            # Sort the timesteps in descending order
-            timesteps, _ = torch.sort(timesteps, descending=True)
-
-            self.timesteps = timesteps.to(device=device)
-
-            return timesteps
-
-
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
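With the `CustomFlowMatchEulerDiscreteScheduler` subclass gone, the script uses the stock `FlowMatchEulerDiscreteScheduler` loaded in the next hunk. The interpolation the deleted `add_noise` implemented, `zt = (1 - t) * x + t * noise`, can be reproduced from the stock scheduler's `sigmas`; the toy sketch below assumes uniform timestep sampling, whereas the trainer weights its timestep sampling differently.

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

latents = torch.randn(2, 16, 64, 64)  # stand-in for VAE latents
noise = torch.randn_like(latents)

# pick per-sample training timesteps by index and look up the matching sigmas
indices = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],))
sigmas = scheduler.sigmas[indices].view(-1, 1, 1, 1)

# zt = (1 - sigma) * x + sigma * noise -- the same interpolation as the removed add_noise
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
print(noisy_latents.shape)
```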
@@ -1499,7 +1426,7 @@ def main(args):
         )
 
     # Load scheduler and models
-    noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
+    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler"
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
@@ -1619,7 +1546,6 @@ def main(args):
         target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
-
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
@@ -1727,7 +1653,6 @@ def main(args):
        cast_training_params(models, dtype=torch.float32)
 
    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-
    if args.train_text_encoder:
        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
        # if we use textual inversion, we freeze all parameters except for the token embeddings
@@ -1737,7 +1662,8 @@ def main(args):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param.data = param.to(dtype=torch.float32)
+                if args.mixed_precision == "fp16":
+                    param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1747,7 +1673,8 @@ def main(args):
             for name, param in text_encoder_two.named_parameters():
                 if "shared" in name:
                     # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                    param.data = param.to(dtype=torch.float32)
+                    if args.mixed_precision == "fp16":
+                        param.data = param.to(dtype=torch.float32)
                     param.requires_grad = True
                     text_lora_parameters_two.append(param)
                 else:
@@ -1828,6 +1755,7 @@ def main(args):
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW
+
    optimizer = optimizer_class(
        params_to_optimize,
        betas=(args.adam_beta1, args.adam_beta2),
@@ -2021,6 +1949,7 @@ def main(args):
            lr_scheduler,
        )
    else:
+        print("I SHOULD BE HERE")
        transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
        )
@@ -2125,7 +2054,7 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
         elif args.train_text_encoder_ti:  # textual inversion / pivotal tuning
             text_encoder_one.train()
             if args.enable_t5_ti:
@@ -2137,6 +2066,11 @@ def main(args):
                 pivoted_tr = True
 
         for step, batch in enumerate(train_dataloader):
+            models_to_accumulate = [transformer]
+            if not freeze_text_encoder:
+                models_to_accumulate.extend([text_encoder_one])
+                if args.enable_t5_ti:
+                    models_to_accumulate.extend([text_encoder_two])
             if pivoted_te:
                 # stopping optimization of text_encoder params
                 optimizer.param_groups[te_idx]["lr"] = 0.0
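`models_to_accumulate` collects every model that receives gradients so that the `accelerator.accumulate(...)` call in the next hunk covers all of them. A toy sketch of the same structure with placeholder `nn.Linear` modules standing in for the transformer and text encoder; the list is passed to `accumulate` the same way the script does.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

transformer = torch.nn.Linear(8, 8)       # placeholder for the Flux transformer
text_encoder_one = torch.nn.Linear(8, 8)  # placeholder for the CLIP text encoder
optimizer = torch.optim.AdamW(
    list(transformer.parameters()) + list(text_encoder_one.parameters()), lr=1e-4
)
transformer, text_encoder_one, optimizer = accelerator.prepare(transformer, text_encoder_one, optimizer)

train_text_encoder = True
models_to_accumulate = [transformer]
if train_text_encoder:
    models_to_accumulate.append(text_encoder_one)

for step in range(8):
    with accelerator.accumulate(models_to_accumulate):
        x = torch.randn(4, 8, device=accelerator.device)
        loss = (text_encoder_one(transformer(x)) ** 2).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```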
@@ -2145,7 +2079,7 @@ def main(args):
                    logger.info(f"PIVOT TRANSFORMER {epoch}")
                    optimizer.param_groups[0]["lr"] = 0.0
 
-            with accelerator.accumulate(transformer):
+            with accelerator.accumulate(models_to_accumulate):
                prompts = batch["prompts"]
 
                # encode batch prompts when custom prompts are provided for each image -
@@ -2189,7 +2123,7 @@ def main(args):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
 
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
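The `- 1` above is the actual fix: the VAE halves spatial resolution between consecutive blocks, so four entries in `block_out_channels` mean three downsampling steps and a scale factor of 8, not 16. A quick check, assuming the usual Flux VAE configuration:

```python
# four blocks -> three downsampling steps between them
vae_config_block_out_channels = [128, 256, 512, 512]  # typical Flux AutoencoderKL config

wrong = 2 ** len(vae_config_block_out_channels)        # 16
fixed = 2 ** (len(vae_config_block_out_channels) - 1)  # 8

height, width = 1024, 1024
print(height // wrong, width // wrong)  # 64 64   -> latent grid would be too small
print(height // fixed, width // fixed)  # 128 128 -> correct latent resolution
```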
@@ -2228,7 +2162,7 @@ def main(args):
                 )
 
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -2288,16 +2222,26 @@ def main(args):
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    if not freeze_text_encoder:
-                        if args.train_text_encoder:
+                        if args.train_text_encoder:  # text encoder tuning
                            params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
                        elif pure_textual_inversion:
-                            params_to_clip = itertools.chain(
-                                text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    text_encoder_one.parameters(), text_encoder_two.parameters()
+                                )
+                            else:
+                                params_to_clip = itertools.chain(text_encoder_one.parameters())
                        else:
-                            params_to_clip = itertools.chain(
-                                transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    transformer.parameters(),
+                                    text_encoder_one.parameters(),
+                                    text_encoder_two.parameters(),
+                                )
+                            else:
+                                params_to_clip = itertools.chain(
+                                    transformer.parameters(), text_encoder_one.parameters()
+                                )
                    else:
                        params_to_clip = itertools.chain(transformer.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -2339,6 +2283,10 @@ def main(args):
 
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
+                        if args.train_text_encoder_ti:
+                            embedding_handler.save_embeddings(
+                                f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors"
+                            )
                        logger.info(f"Saved state to {save_path}")
 
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -2351,14 +2299,16 @@ def main(args):
        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                # create pipeline
-                if freeze_text_encoder:
+                if freeze_text_encoder:  # no text encoder one, two optimizations
                    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+                    text_encoder_one.to(weight_dtype)
+                    text_encoder_two.to(weight_dtype)
                pipeline = FluxPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                    transformer=accelerator.unwrap_model(transformer),
+                    text_encoder=unwrap_model(text_encoder_one),
+                    text_encoder_2=unwrap_model(text_encoder_two),
+                    transformer=unwrap_model(transformer),
                    revision=args.revision,
                    variant=args.variant,
                    torch_dtype=weight_dtype,
@@ -2372,21 +2322,21 @@ def main(args):
                    epoch=epoch,
                    torch_dtype=weight_dtype,
                )
-                images = None
-                del pipeline
-
                if freeze_text_encoder:
                    del text_encoder_one, text_encoder_two
                    free_memory()
-                elif args.train_text_encoder:
-                    del text_encoder_two
-                    free_memory()
 
+                images = None
+                del pipeline
+
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
-        transformer = transformer.to(weight_dtype)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
 
        if args.train_text_encoder:
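The `--upcast_before_saving` branch above decides the dtype of the LoRA state dict that `get_peft_model_state_dict` extracts. A toy illustration with a stand-in module instead of the Flux transformer; it relies on PEFT's custom-model support rather than the training script's own wiring.

```python
import torch
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict

# stand-in for the transformer, just to show that the saved dtype follows
# whatever dtype the module is in at save time
base = torch.nn.Sequential(torch.nn.Linear(16, 16))
lora_model = get_peft_model(base, LoraConfig(r=4, target_modules=["0"]))
lora_model.to(torch.bfloat16)  # "weight_dtype" during training

upcast_before_saving = True
if upcast_before_saving:
    lora_model.to(torch.float32)

state_dict = get_peft_model_state_dict(lora_model)
print({k: v.dtype for k, v in state_dict.items()})  # float32 when upcast, bfloat16 otherwise
```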
@@ -2428,8 +2378,8 @@ def main(args):
            accelerator=accelerator,
            pipeline_args=pipeline_args,
            epoch=epoch,
-            torch_dtype=weight_dtype,
            is_final_validation=True,
+            torch_dtype=weight_dtype,
        )
 
        save_model_card(
@@ -2452,6 +2402,7 @@ def main(args):
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )
 
+        images = None
        del pipeline
 
@@ -895,7 +895,10 @@ def _encode_prompt_with_t5(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(
 
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -958,7 +965,12 @@ def encode_prompt(
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
-    dtype = text_encoders[0].dtype
 
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
+
+    device = device if device is not None else text_encoders[1].device
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
@@ -1590,7 +1602,7 @@ def main(args):
                 )
 
                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1716,9 +1728,9 @@ def main(args):
        pipeline = FluxPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            vae=vae,
-            text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
-            text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
-            transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
+            text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
+            text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
+            transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
@@ -177,16 +177,25 @@ def log_validation(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
    pipeline.set_progress_bar_config(disable=True)
 
    # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
-    autocast_ctx = nullcontext()
+    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
 
-    with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+        )
+    images = []
+    for _ in range(args.num_validation_images):
+        with autocast_ctx:
+            image = pipeline(
+                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
+            ).images[0]
+        images.append(image)
 
    for tracker in accelerator.trackers:
        phase_name = "test" if is_final_validation else "validation"
@@ -203,8 +212,7 @@ def log_validation(
    )
 
    del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    free_memory()
 
    return images
 
@@ -932,7 +940,10 @@ def _encode_prompt_with_t5(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -973,9 +984,13 @@ def _encode_prompt_with_clip(
 
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -994,7 +1009,11 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = text_encoders[0].dtype
 
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
+
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
@@ -1619,7 +1638,7 @@ def main(args):
        if args.train_text_encoder:
            text_encoder_one.train()
            # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
 
        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
@@ -1710,7 +1729,7 @@ def main(args):
                )
 
                # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                    guidance = guidance.expand(model_input.shape[0])
                else:
@@ -1828,9 +1847,9 @@ def main(args):
                pipeline = FluxPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                    transformer=accelerator.unwrap_model(transformer),
+                    text_encoder=unwrap_model(text_encoder_one),
+                    text_encoder_2=unwrap_model(text_encoder_two),
+                    transformer=unwrap_model(transformer),
                    revision=args.revision,
                    variant=args.variant,
                    torch_dtype=weight_dtype,