diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index b26f2218f4..52cd99c046 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -386,13 +386,13 @@ def main(args): ] ) - def transforms(examples): + def transform_images(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] return {"input": images} logger.info(f"Dataset size: {len(dataset)}") - dataset.set_transform(transforms) + dataset.set_transform(transform_images) train_dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers ) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 9a72463bb3..d4df7adacb 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -386,13 +386,13 @@ def main(args): ] ) - def transforms(examples): + def transform_images(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] return {"input": images} logger.info(f"Dataset size: {len(dataset)}") - dataset.set_transform(transforms) + dataset.set_transform(transform_images) train_dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers )