mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix torchvision.transforms and transforms function naming clash (#2274)
* Fix torchvision.transforms and transforms function naming clash * Update unconditional script for onnx * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
@@ -386,13 +386,13 @@ def main(args):
|
||||
]
|
||||
)
|
||||
|
||||
def transforms(examples):
|
||||
def transform_images(examples):
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
|
||||
logger.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
dataset.set_transform(transforms)
|
||||
dataset.set_transform(transform_images)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
@@ -386,13 +386,13 @@ def main(args):
|
||||
]
|
||||
)
|
||||
|
||||
def transforms(examples):
|
||||
def transform_images(examples):
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
|
||||
logger.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
dataset.set_transform(transforms)
|
||||
dataset.set_transform(transform_images)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user