mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Set LANCZOS as the default interpolation for image resizing in ControlNet training (#11449)
Set LANCZOS as the default interpolation for image resizing
This commit is contained in:
@@ -639,6 +639,15 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
help="Enable model cpu offload and save memory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -736,9 +745,13 @@ def get_train_dataset(args, accelerator):
|
||||
|
||||
|
||||
def prepare_train_dataset(dataset, accelerator):
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
|
||||
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
@@ -747,7 +760,7 @@ def prepare_train_dataset(dataset, accelerator):
|
||||
|
||||
conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
|
||||
Reference in New Issue
Block a user