From af2a237676ada656889de5e5b96ce609e37ed8c4 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 11 May 2023 08:59:20 -0700 Subject: [PATCH] [deepspeed] partial ZeRO-3 support (#3076) * [deepspeed] partial ZeRO-3 support * cleanup * improve deepspeed fixes * Improve * make style --------- Co-authored-by: Patrick von Platen --- examples/text_to_image/train_text_to_image.py | 34 ++++++++++++++++--- src/diffusers/training_utils.py | 23 ++++++++++--- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index f9592e5adc..1a6f4cde27 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -29,6 +29,7 @@ import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger +from accelerate.state import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder @@ -36,6 +37,7 @@ from packaging import version from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel @@ -464,10 +466,34 @@ def main(): tokenizer = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState() if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 340b96e29a..1a3abb49a0 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,3 +1,4 @@ +import contextlib import copy import os import random @@ -6,7 +7,11 @@ from typing import Any, Dict, Iterable, Optional, Union import numpy as np import torch -from .utils import deprecate +from .utils import deprecate, is_transformers_available + + +if is_transformers_available(): + import transformers def enable_full_determinism(seed: int): @@ -197,11 +202,19 @@ class EMAModel: self.cur_decay_value = decay one_minus_decay = 1 - decay + context_manager = contextlib.nullcontext + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): + import deepspeed + for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) + + with context_manager(): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: """