mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Store vae.config.scaling_factor to prevent missing attr reference (sdxl advanced dreambooth training script) (#12346)
Store vae.config.scaling_factor to prevent missing attr reference In sdxl advanced dreambooth training script vae.config.scaling_factor becomes inaccessible after: del vae when: --cache_latents, and no --validation_prompt Co-authored-by: Teriks <Teriks@users.noreply.github.com> Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -1929,6 +1929,8 @@ def main(args):
|
||||
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
# Store vae config before potential deletion
|
||||
vae_scaling_factor = vae.config.scaling_factor
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
with torch.no_grad():
|
||||
batch["pixel_values"] = batch["pixel_values"].to(
|
||||
@@ -1940,6 +1942,8 @@ def main(args):
|
||||
del vae
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
vae_scaling_factor = vae.config.scaling_factor
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
@@ -2109,13 +2113,13 @@ def main(args):
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
|
||||
if latents_mean is None and latents_std is None:
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
model_input = model_input * vae_scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
model_input = model_input.to(weight_dtype)
|
||||
else:
|
||||
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
|
||||
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
|
||||
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
|
||||
model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
|
||||
Reference in New Issue
Block a user