mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[deepspeed] partial ZeRO-3 support (#3076)
* [deepspeed] partial ZeRO-3 support * cleanup * improve deepspeed fixes * Improve * make style --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -29,6 +29,7 @@ import torch.utils.checkpoint
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
@@ -36,6 +37,7 @@ from packaging import version
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from transformers.utils import ContextManagers
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
@@ -464,10 +466,34 @@ def main():
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
||||
|
||||
def deepspeed_zero_init_disabled_context_manager():
    """
    Return a list of context managers that disable DeepSpeed's ``zero.Init``,
    or an empty list when no DeepSpeed plugin is configured.

    Intended to wrap ``from_pretrained`` calls for the frozen models
    (CLIPTextModel, AutoencoderKL) so they are NOT partitioned under ZeRO
    stage 3; only the trained UNet is sharded across GPUs (see the comment
    block at the call site).

    Returns:
        list: ``[plugin.zero3_init_context_manager(enable=False)]`` when a
        DeepSpeed plugin is active, otherwise ``[]``.
    """
    # BUG FIX: the original bound `AcceleratorState()` itself to
    # `deepspeed_plugin`. The plugin lives on the state's
    # `.deepspeed_plugin` attribute, and the state instance is never None,
    # so the guard below could never return [] and the call on the last
    # line would target an object without `zero3_init_context_manager`.
    deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
    if deepspeed_plugin is None:
        return []

    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
|
||||
|
||||
# Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
|
||||
# For this to work properly all models must be run through `accelerate.prepare`. But accelerate
|
||||
# will try to assign the same optimizer with the same weights to all models during
|
||||
# `deepspeed.initialize`, which of course doesn't work.
|
||||
#
|
||||
# For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
|
||||
# frozen models from being partitioned during `zero.Init` which gets called during
|
||||
# `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
|
||||
# across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
|
||||
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
|
||||
)
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
@@ -6,7 +7,11 @@ from typing import Any, Dict, Iterable, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .utils import deprecate
|
||||
from .utils import deprecate, is_transformers_available
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
|
||||
|
||||
def enable_full_determinism(seed: int):
|
||||
@@ -197,11 +202,19 @@ class EMAModel:
|
||||
self.cur_decay_value = decay
|
||||
one_minus_decay = 1 - decay
|
||||
|
||||
context_manager = contextlib.nullcontext
|
||||
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
if param.requires_grad:
|
||||
s_param.sub_(one_minus_decay * (s_param - param))
|
||||
else:
|
||||
s_param.copy_(param)
|
||||
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
|
||||
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
|
||||
|
||||
with context_manager():
|
||||
if param.requires_grad:
|
||||
s_param.sub_(one_minus_decay * (s_param - param))
|
||||
else:
|
||||
s_param.copy_(param)
|
||||
|
||||
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user