From 0d4dfbbd0a26d463d45a79b4667e288e00c3e0a0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 7 Mar 2024 15:19:58 +0530 Subject: [PATCH] [Examples] fix: prior preservation setting in DreamBooth LoRA SDXL script. (#7242) fix: prior preservation setting in DreamBooth LoRA SDXL script. Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 0d7876db95..6e920d1a22 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -877,6 +877,8 @@ def collate_fn(examples, with_prior_preservation=False): if with_prior_preservation: pixel_values += [example["class_images"] for example in examples] prompts += [example["class_prompt"] for example in examples] + original_sizes += [example["original_size"] for example in examples] + crop_top_lefts += [example["crop_top_left"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()