From 58431f102cf39c3c8a569f32d71b2ea8caa461e1 Mon Sep 17 00:00:00 2001 From: Youlun Peng <116731168+YoulunPeng@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:47:02 +0200 Subject: [PATCH] Set LANCZOS as the default interpolation for image resizing in ControlNet training (#11449) Set LANCZOS as the default interpolation for image resizing --- examples/controlnet/train_controlnet_flux.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index edee3fbe55..232d3da8e8 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -639,6 +639,15 @@ def parse_args(input_args=None): action="store_true", help="Enable model cpu offload and save memory.", ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -736,9 +745,13 @@ def get_train_dataset(args, accelerator): def prepare_train_dataset(dataset, accelerator): + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {interpolation=}.") + image_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=interpolation), transforms.CenterCrop(args.resolution), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), @@ -747,7 +760,7 @@ def prepare_train_dataset(dataset, accelerator): conditioning_image_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=interpolation), transforms.CenterCrop(args.resolution), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]),