From ed6cf52572a4e1efd950ea186ad14bda16405a05 Mon Sep 17 00:00:00 2001
From: Yuanzhou <80858000+ca1yz@users.noreply.github.com>
Date: Sat, 3 May 2025 04:46:01 +0800
Subject: [PATCH] [train_dreambooth_lora_sdxl_advanced] Add LANCZOS as the
 default interpolation mode for image resizing (#11471)

Add an `--image_interpolation_mode` command-line option (default
"lanczos"). Its accepted values are the lower-cased, non-dunder attribute
names of `transforms.InterpolationMode`.

`DreamBoothDataset` resolves the option to an `InterpolationMode` member
once and passes it to both `transforms.Resize` calls, replacing the
previously hard-coded `InterpolationMode.BILINEAR`. An unknown mode name
raises `ValueError`.
---
Review note (text in this section is ignored by `git am`):
the `ValueError` message originally interpolated `{interpolation=}`, which
is always `None` on that branch, so the message never identified the bad
value. It now reports `args.image_interpolation_mode`. Hunk line counts are
unchanged.

 .../train_dreambooth_lora_sdxl_advanced.py       | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 25bbe155ee..dae618f43a 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -799,6 +799,15 @@ def parse_args(input_args=None):
         default=False,
         help="Cache the VAE latents",
     )
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1069,7 +1078,10 @@ class DreamBoothDataset(Dataset):
         self.original_sizes = []
         self.crop_top_lefts = []
         self.pixel_values = []
-        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+        if interpolation is None:
+            raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode!r}.")
+        train_resize = transforms.Resize(size, interpolation=interpolation)
         train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
         train_flip = transforms.RandomHorizontalFlip(p=1.0)
         train_transforms = transforms.Compose(
@@ -1146,7 +1158,7 @@ class DreamBoothDataset(Dataset):
 
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.Resize(size, interpolation=interpolation),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),