mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Update train_unconditional.py (#3899)
increase the time of timeout when using big dataset or high resolution
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -11,7 +12,7 @@ import accelerate
|
||||
import datasets
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration
|
||||
from datasets import load_dataset
|
||||
@@ -286,11 +287,13 @@ def main(args):
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))#a big number for high resolution or big dataset
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.logger,
|
||||
project_config=accelerator_project_config,
|
||||
kwargs_handlers=[kwargs],
|
||||
)
|
||||
|
||||
if args.logger == "tensorboard":
|
||||
|
||||
Reference in New Issue
Block a user