From 1bd4c9e93dcbb31135aa8594aaf28f7b6efd39ab Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 14 Apr 2023 06:39:25 -1000 Subject: [PATCH] remvoe one line as requested by gc team (#3077) remvoe one line --- examples/text_to_image/train_text_to_image_flax.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 41a02d68f2..d44731896c 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -340,11 +340,10 @@ def main(): return examples - if jax.process_index() == 0: - if args.max_train_samples is not None: - dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) # Set the training transforms - train_dataset = dataset["train"].with_transform(preprocess_train) + train_dataset = dataset["train"].with_transform(preprocess_train) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples])