mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[SDXL ControlNet Training] Follow-up fixes (#4188)
* hash computation. thanks to @lhoestq * disable dtype casting. * remove comments.
This commit is contained in:
@@ -1001,7 +1001,12 @@ def main(args):
|
||||
proportion_empty_prompts=args.proportion_empty_prompts,
|
||||
)
|
||||
with accelerator.main_process_first():
|
||||
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True)
|
||||
from datasets.fingerprint import Hasher
|
||||
|
||||
# fingerprint used by the cache for the other processes to load the result
|
||||
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
|
||||
new_fingerprint = Hasher.hash(args)
|
||||
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
|
||||
|
||||
del text_encoders, tokenizers
|
||||
gc.collect()
|
||||
@@ -1113,8 +1118,6 @@ def main(args):
|
||||
# Convert images to latent space
|
||||
if args.pretrained_vae_model_name_or_path is not None:
|
||||
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
|
||||
if vae.dtype != weight_dtype:
|
||||
vae.to(dtype=weight_dtype)
|
||||
else:
|
||||
pixel_values = batch["pixel_values"]
|
||||
latents = vae.encode(pixel_values).latent_dist.sample()
|
||||
|
||||
Reference in New Issue
Block a user