1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix torchvision.transforms and transforms function naming clash (#2274)

* Fix torchvision.transforms and transforms function naming clash

* Update unconditional script for onnx

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
wfng92
2023-02-08 00:36:32 +08:00
committed by GitHub
parent bbb46ad3d5
commit 111228cb39
2 changed files with 4 additions and 4 deletions

View File

@@ -386,13 +386,13 @@ def main(args):
]
)
def transforms(examples):
def transform_images(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
logger.info(f"Dataset size: {len(dataset)}")
dataset.set_transform(transforms)
dataset.set_transform(transform_images)
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)

View File

@@ -386,13 +386,13 @@ def main(args):
]
)
def transforms(examples):
def transform_images(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
logger.info(f"Dataset size: {len(dataset)}")
dataset.set_transform(transforms)
dataset.set_transform(transform_images)
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)