From 723dbdd36300cd5a14000b828aaef87ba7e1fa68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Tue, 8 Apr 2025 02:56:07 -0400 Subject: [PATCH] [Training] Better image interpolation in training scripts (#11206) * initial * Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: hlky * update --------- Co-authored-by: Sayak Paul Co-authored-by: hlky --- .../dreambooth/train_dreambooth_lora_sdxl.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index f0d993ad9b..735d48b834 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -669,6 +669,16 @@ def parse_args(input_args=None): ), ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) + if input_args is not None: args = parser.parse_args(input_args) else: @@ -790,7 +800,12 @@ class DreamBoothDataset(Dataset): self.original_sizes = [] self.crop_top_lefts = [] self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {interpolation=}.") + train_resize = transforms.Resize(size, interpolation=interpolation) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose(