mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add LANCZOS as default interplotation mode. (#11463)
* Add LANCZOS as default interplotation mode. * LANCZOS as default interplotation * LANCZOS as default interplotation mode * Added LANCZOS as default interplotation mode
This commit is contained in:
@@ -134,7 +134,25 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
|
||||
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
|
||||
validation_image = Image.open(validation_image).convert("RGB")
|
||||
validation_image = validation_image.resize((args.resolution, args.resolution))
|
||||
|
||||
try:
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
|
||||
except (AttributeError, KeyError):
|
||||
supported_interpolation_modes = [
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
]
|
||||
raise ValueError(
|
||||
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
|
||||
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
|
||||
)
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=interpolation),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
]
|
||||
)
|
||||
validation_image = transform(validation_image)
|
||||
|
||||
images = []
|
||||
|
||||
@@ -587,6 +605,15 @@ def parse_args(input_args=None):
|
||||
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -732,9 +759,20 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
|
||||
|
||||
|
||||
def prepare_train_dataset(dataset, accelerator):
|
||||
try:
|
||||
interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
|
||||
except (AttributeError, KeyError):
|
||||
supported_interpolation_modes = [
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
]
|
||||
raise ValueError(
|
||||
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
|
||||
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
|
||||
)
|
||||
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation_mode),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
@@ -743,7 +781,7 @@ def prepare_train_dataset(dataset, accelerator):
|
||||
|
||||
conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation_mode),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user