From da95a28ff673c4f0c772fc9746870c4f92f0dfb1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 22 Jan 2024 20:14:54 +0530 Subject: [PATCH] [Diffusion DPO] apply fixes from #6547 (#6668) apply fixes from #6547 --- .../diffusion_dpo/train_diffusion_dpo_sdxl.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py index 23e94bc679..3fd3fc11ad 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py @@ -740,6 +740,10 @@ def main(args): # Resize. combined_im = train_resize(combined_im) + # Flipping. + if not args.no_flip and random.random() < 0.5: + combined_im = train_flip(combined_im) + # Cropping. if not args.random_crop: y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0))) @@ -749,11 +753,6 @@ def main(args): y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution)) combined_im = crop(combined_im, y1, x1, h, w) - # Flipping. - if random.random() < 0.5: - x1 = combined_im.shape[2] - x1 - combined_im = train_flip(combined_im) - crop_top_left = (y1, x1) crop_top_lefts.append(crop_top_left) combined_im = normalize(combined_im)