diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 41a02d68f2..d44731896c 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -340,11 +340,10 @@ def main(): return examples - if jax.process_index() == 0: - if args.max_train_samples is not None: - dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) # Set the training transforms - train_dataset = dataset["train"].with_transform(preprocess_train) + train_dataset = dataset["train"].with_transform(preprocess_train) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples])