
Move buffers to device (#10523)

* Move buffers to device

* add test

* named_buffers
Author: hlky
Date: 2025-01-16 17:42:56 +00:00
Committed by: GitHub
Parent: b785ddb654
Commit: 0b065c099a
4 changed files with 56 additions and 2 deletions
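
In short: load_model_dict_into_meta assigns tensors by iterating the checkpoint's state dict, so buffers that never appear in the checkpoint (for example, non-persistent ones) were left on their initial device while the parameters were moved. This commit threads model.named_buffers() into the loader so buffers are placed explicitly. Below is a minimal sketch of the failure mode and of the fix in spirit, in plain PyTorch; the TinyModel class is hypothetical and not part of this diff.

import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        # Non-persistent buffer: excluded from state_dict, so a loader that
        # only walks checkpoint keys never assigns it a device.
        self.register_buffer("scale", torch.ones(4), persistent=False)

model = TinyModel()
assert "scale" not in model.state_dict()  # invisible to a state-dict-driven loader

# The fix in spirit: iterate named_buffers() and move each buffer explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for name, buf in model.named_buffers():
    module_path, _, attr = name.rpartition(".")
    submodule = model.get_submodule(module_path) if module_path else model
    setattr(submodule, attr, buf.to(device))  # re-registers the moved buffer

assert model.scale.device.type == device.type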

src/diffusers/loaders/single_file_model.py

@@ -362,6 +362,7 @@ class FromOriginalModelMixin:
         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
+            named_buffers = model.named_buffers()
             unexpected_keys = load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
@@ -369,6 +370,7 @@
                 device=param_device,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
+                named_buffers=named_buffers,
             )
         else:

src/diffusers/models/model_loading_utils.py

@@ -20,7 +20,7 @@ import os
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union
 
 import safetensors
 import torch
@@ -193,6 +193,7 @@ def load_model_dict_into_meta(
     model_name_or_path: Optional[str] = None,
     hf_quantizer=None,
     keep_in_fp32_modules=None,
+    named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
 ) -> List[str]:
     if device is not None and not isinstance(device, (str, torch.device)):
         raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
@@ -254,6 +255,20 @@
             else:
                 set_module_tensor_to_device(model, param_name, device, value=param)
 
+    if named_buffers is None:
+        return unexpected_keys
+
+    for param_name, param in named_buffers:
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+        else:
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
+
     return unexpected_keys
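
Note the type of the new parameter: Module.named_buffers() returns a one-shot generator, which is why each call site captures it immediately before calling load_model_dict_into_meta rather than reusing it. A standalone illustration (not part of the diff):

import torch

model = torch.nn.BatchNorm1d(8)  # registers running_mean/running_var buffers
buffers = model.named_buffers()

print([name for name, _ in buffers])  # ['running_mean', 'running_var', 'num_batches_tracked']
print([name for name, _ in buffers])  # [] -- the generator is already exhausted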

src/diffusers/models/modeling_utils.py

@@ -913,6 +913,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                                 " those weights or else make sure your checkpoint file is correct."
                             )
 
+                        named_buffers = model.named_buffers()
+
                         unexpected_keys = load_model_dict_into_meta(
                             model,
                             state_dict,
@@ -921,6 +923,7 @@
                             model_name_or_path=pretrained_model_name_or_path,
                             hf_quantizer=hf_quantizer,
                             keep_in_fp32_modules=keep_in_fp32_modules,
+                            named_buffers=named_buffers,
                         )
 
                 if cls._keys_to_ignore_on_load_unexpected is not None:

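With this change both entry points, FromOriginalModelMixin.from_single_file and ModelMixin.from_pretrained, capture model.named_buffers() right before calling load_model_dict_into_meta, so every buffer ends up on param_device even when it never appears in the checkpoint's state dict.
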
tests/quantization/bnb/test_mixed_int8.py

@@ -20,7 +20,14 @@ import numpy as np
 import pytest
 from huggingface_hub import hf_hub_download
 
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxTransformer2DModel,
+    SanaTransformer2DModel,
+    SD3Transformer2DModel,
+    logging,
+)
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -302,6 +309,33 @@ class BnB8bitBasicTests(Base8bitTests):
         _ = self.model_fp16.cuda()
 
 
+class Bnb8bitDeviceTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+        self.model_8bit = SanaTransformer2DModel.from_pretrained(
+            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
+            subfolder="transformer",
+            quantization_config=mixed_int8_config,
+        )
+
+    def tearDown(self):
+        del self.model_8bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_buffers_device_assignment(self):
+        for buffer_name, buffer in self.model_8bit.named_buffers():
+            self.assertEqual(
+                buffer.device.type,
+                torch.device(torch_device).type,
+                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
+            )
+
+
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
         gc.collect()
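
The new Bnb8bitDeviceTests case loads the Sana transformer in 8-bit and asserts that every buffer reported by named_buffers() lands on torch_device; before this fix such buffers could remain on CPU and cause device-mismatch errors at inference. Running it requires a CUDA device and bitsandbytes, presumably via something like pytest -k test_buffers_device_assignment in the bnb quantization test module.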