mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Move buffers to device (#10523)
* Move buffers to device
* add test
* named_buffers
This commit is contained in:
@@ -362,6 +362,7 @@ class FromOriginalModelMixin:
|
||||
|
||||
if is_accelerate_available():
|
||||
param_device = torch.device(device) if device else torch.device("cpu")
|
||||
named_buffers = model.named_buffers()
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
diffusers_format_checkpoint,
|
||||
@@ -369,6 +370,7 @@ class FromOriginalModelMixin:
|
||||
device=param_device,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
named_buffers=named_buffers,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -20,7 +20,7 @@ import os
|
||||
from array import array
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -193,6 +193,7 @@ def load_model_dict_into_meta(
|
||||
model_name_or_path: Optional[str] = None,
|
||||
hf_quantizer=None,
|
||||
keep_in_fp32_modules=None,
|
||||
named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
|
||||
) -> List[str]:
|
||||
if device is not None and not isinstance(device, (str, torch.device)):
|
||||
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
|
||||
@@ -254,6 +255,20 @@ def load_model_dict_into_meta(
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, device, value=param)
|
||||
|
||||
if named_buffers is None:
|
||||
return unexpected_keys
|
||||
|
||||
for param_name, param in named_buffers:
|
||||
if is_quantized and (
|
||||
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
|
||||
):
|
||||
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
|
||||
else:
|
||||
if accepts_dtype:
|
||||
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, device, value=param)
|
||||
|
||||
return unexpected_keys
|
||||
|
||||
|
||||
|
||||
@@ -913,6 +913,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
" those weights or else make sure your checkpoint file is correct."
|
||||
)
|
||||
|
||||
named_buffers = model.named_buffers()
|
||||
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
state_dict,
|
||||
@@ -921,6 +923,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
model_name_or_path=pretrained_model_name_or_path,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
named_buffers=named_buffers,
|
||||
)
|
||||
|
||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||
|
||||
@@ -20,7 +20,14 @@ import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
|
||||
from diffusers import (
|
||||
BitsAndBytesConfig,
|
||||
DiffusionPipeline,
|
||||
FluxTransformer2DModel,
|
||||
SanaTransformer2DModel,
|
||||
SD3Transformer2DModel,
|
||||
logging,
|
||||
)
|
||||
from diffusers.utils import is_accelerate_version
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
@@ -302,6 +309,33 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
_ = self.model_fp16.cuda()
|
||||
|
||||
|
||||
class Bnb8bitDeviceTests(Base8bitTests):
    """Loads a Sana transformer with an 8-bit bitsandbytes config and checks the device of its buffers."""

    def setUp(self) -> None:
        # Run garbage collection and empty the CUDA cache before the model is loaded.
        gc.collect()
        torch.cuda.empty_cache()

        int8_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model_8bit = SanaTransformer2DModel.from_pretrained(
            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
            subfolder="transformer",
            quantization_config=int8_config,
        )

    def tearDown(self):
        # Remove the model attribute first, then collect garbage and empty the CUDA cache.
        del self.model_8bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_buffers_device_assignment(self):
        """Assert that each named buffer's device type equals the device type of `torch_device`."""
        for buffer_name, buffer in self.model_8bit.named_buffers():
            actual_type = buffer.device.type
            expected_type = torch.device(torch_device).type
            failure_message = f"Expected device {torch_device} for {buffer_name} got {buffer.device}."
            self.assertEqual(actual_type, expected_type, failure_message)
|
||||
|
||||
|
||||
class BnB8bitTrainingTests(Base8bitTests):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user