From 111228cb396f0ed33cdeb7dc718e20d7d629d2f1 Mon Sep 17 00:00:00 2001 From: wfng92 <43742196+wfng92@users.noreply.github.com> Date: Wed, 8 Feb 2023 00:36:32 +0800 Subject: [PATCH] Fix torchvision.transforms and transforms function naming clash (#2274) * Fix torchvision.transforms and transforms function naming clash * Update unconditional script for onnx * Apply suggestions from code review Co-authored-by: Pedro Cuenca --------- Co-authored-by: Patrick von Platen Co-authored-by: Pedro Cuenca --- .../unconditional_image_generation/train_unconditional.py | 4 ++-- .../unconditional_image_generation/train_unconditional.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index b26f2218f4..52cd99c046 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -386,13 +386,13 @@ def main(args): ] ) - def transforms(examples): + def transform_images(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] return {"input": images} logger.info(f"Dataset size: {len(dataset)}") - dataset.set_transform(transforms) + dataset.set_transform(transform_images) train_dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers ) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 9a72463bb3..d4df7adacb 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -386,13 +386,13 @@ def main(args): ] ) - def transforms(examples): + def transform_images(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] return {"input": images} logger.info(f"Dataset size: {len(dataset)}") - dataset.set_transform(transforms) + dataset.set_transform(transform_images) train_dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers )