1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Store vae.config.scaling_factor to prevent missing attr reference (sdxl advanced dreambooth training script) (#12346)

Store vae.config.scaling_factor to prevent missing attr reference

In the SDXL advanced DreamBooth training script, `vae.config.scaling_factor` becomes inaccessible after `del vae`, which happens when `--cache_latents` is set and no `--validation_prompt` is given.

Co-authored-by: Teriks <Teriks@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
Teriks
2026-01-09 13:42:30 -06:00
committed by GitHub
parent 644169433f
commit 57e57cfae0

View File

@@ -1929,6 +1929,8 @@ def main(args):
if args.cache_latents:
latents_cache = []
+# Store vae config before potential deletion
+vae_scaling_factor = vae.config.scaling_factor
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
@@ -1940,6 +1942,8 @@ def main(args):
del vae
if torch.cuda.is_available():
torch.cuda.empty_cache()
+else:
+vae_scaling_factor = vae.config.scaling_factor
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
@@ -2109,13 +2113,13 @@ def main(args):
model_input = vae.encode(pixel_values).latent_dist.sample()
if latents_mean is None and latents_std is None:
-model_input = model_input * vae.config.scaling_factor
+model_input = model_input * vae_scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
else:
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
-model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std
model_input = model_input.to(dtype=weight_dtype)
# Sample noise that we'll add to the latents