mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[FEAT] Model loading refactor (#10604)

* first draft model loading refactor

* revert name change

* fix bnb

* revert name

* fix dduf

* fix hunyuan

* style

* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* suggestions from reviews

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove safetensors check

* fix default value

* more fix from suggestions

* revert logic for single file

* style

* typing + fix couple of issues

* improve speed

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Aryan <aryan@huggingface.co>

* fp8 dtype

* add tests

* rename resolved_archive_file to resolved_model_file

* format

* map_location default cpu

* add utility function

* switch to smaller model + test inference

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* rm comment

* add log

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* add decorator

* cosine sim instead

* fix use_keep_in_fp32_modules

* comm

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Marc Sun
2025-02-19 13:04:53 +01:00
committed by GitHub
parent 6fe05b9b93
commit f5929e0306
12 changed files with 842 additions and 513 deletions


@@ -52,7 +52,7 @@ logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate import dispatch_model, init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
@@ -366,19 +366,23 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
)
device_map = None
if is_accelerate_available():
param_device = torch.device(device) if device else torch.device("cpu")
named_buffers = model.named_buffers()
unexpected_keys = load_model_dict_into_meta(
empty_state_dict = model.state_dict()
unexpected_keys = [
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
]
device_map = {"": param_device}
load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device=param_device,
device_map=device_map,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
named_buffers=named_buffers,
unexpected_keys=unexpected_keys,
)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -400,4 +404,8 @@ class FromOriginalModelMixin:
model.eval()
if device_map is not None:
device_map_kwargs = {"device_map": device_map}
dispatch_model(model, **device_map_kwargs)
return model


@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
model.load_state_dict(diffusers_format_checkpoint, strict=False)
if torch_dtype is not None:
model.to(torch_dtype)
@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
model.load_state_dict(diffusers_format_checkpoint)


@@ -20,13 +20,15 @@ import os
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
import safetensors
import torch
from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError
from ..quantizers import DiffusersQuantizer
from ..utils import (
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
@@ -55,7 +57,7 @@ _CLASS_REMAPPING_DICT = {
if is_accelerate_available():
from accelerate import infer_auto_device_map
from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
# Adapted from `transformers` (see modeling_utils.py)
@@ -132,17 +134,46 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class
def _check_archive_and_maybe_raise_error(checkpoint_file, format_list):
"""
Check format of the archive
"""
with safetensors.safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in format_list:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
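The format check above relies on the free-form metadata header that safetensors files can carry; a minimal illustration of writing and reading that header (hypothetical file name, not part of the commit):

    import torch
    import safetensors.torch

    # save_file accepts an optional string-to-string metadata dict; "format": "pt" is what the check expects.
    safetensors.torch.save_file({"w": torch.zeros(2)}, "tiny.safetensors", metadata={"format": "pt"})

    with safetensors.safe_open("tiny.safetensors", framework="pt") as f:
        assert f.metadata().get("format") == "pt"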
def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
"""
Find the device of param_name from the device_map.
"""
if device_map is None:
return "cpu"
else:
module_name = param_name
# find next higher level module that is defined in device_map:
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
if module_name == "" and "" not in device_map:
raise ValueError(f"{param_name} doesn't have any device set.")
return device_map[module_name]
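To illustrate the lookup `_determine_param_device` performs: the parameter name is trimmed one dotted component at a time until it matches a device_map entry, falling back to the root entry "". A standalone sketch with a hypothetical device_map:

    device_map = {"": "cpu", "transformer_blocks": 0, "transformer_blocks.3": "disk"}

    def resolve(param_name, device_map):
        module_name = param_name
        # bert.lm_head.weight -> bert.lm_head -> bert -> ''
        while len(module_name) > 0 and module_name not in device_map:
            module_name = ".".join(module_name.split(".")[:-1])
        return device_map[module_name]

    assert resolve("transformer_blocks.3.ff.weight", device_map) == "disk"
    assert resolve("transformer_blocks.0.attn.to_q.weight", device_map) == 0
    assert resolve("proj_out.weight", device_map) == "cpu"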
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
variant: Optional[str] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
disable_mmap: bool = False,
map_location: Union[str, torch.device] = "cpu",
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
# TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
# when refactoring the _merge_sharded_checkpoints() method later.
# TODO: maybe refactor a bit this part where we pass a dict here
if isinstance(checkpoint_file, dict):
return checkpoint_file
try:
@@ -152,19 +183,26 @@ def load_state_dict(
# tensors are loaded on cpu
with dduf_entries[checkpoint_file].as_mmap() as mm:
return safetensors.torch.load(mm)
_check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"])
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
return safetensors.torch.load_file(checkpoint_file, device=map_location)
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
extra_args = {}
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
checkpoint_file,
map_location="cpu",
**weights_only_kwarg,
)
# mmap can only be used with files serialized with zipfile-based format.
if (
isinstance(checkpoint_file, str)
and map_location != "meta"
and is_torch_version(">=", "2.1.0")
and is_zipfile(checkpoint_file)
and not disable_mmap
):
extra_args = {"mmap": True}
return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
except Exception as e:
try:
with open(checkpoint_file) as f:
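The hunk above opts into memory-mapped loading for .bin checkpoints: from torch 2.1 onward, torch.load accepts mmap=True for files saved in the zipfile-based format, which avoids reading the whole file into RAM up front. A standalone sketch (hypothetical file name):

    import os, tempfile
    import torch
    from zipfile import is_zipfile

    path = os.path.join(tempfile.mkdtemp(), "model.bin")  # hypothetical checkpoint path
    torch.save({"w": torch.zeros(4)}, path)                # torch.save uses the zipfile format by default

    extra_args = {"mmap": True} if is_zipfile(path) else {}
    state_dict = torch.load(path, map_location="cpu", weights_only=True, **extra_args)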
@@ -188,23 +226,24 @@ def load_state_dict(
def load_model_dict_into_meta(
model,
state_dict: OrderedDict,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
hf_quantizer=None,
keep_in_fp32_modules=None,
named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
hf_quantizer: Optional[DiffusersQuantizer] = None,
keep_in_fp32_modules: Optional[List] = None,
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
unexpected_keys: Optional[List[str]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
offload_index: Optional[Dict] = None,
state_dict_index: Optional[Dict] = None,
state_dict_folder: Optional[Union[str, os.PathLike]] = None,
) -> List[str]:
if device is not None and not isinstance(device, (str, torch.device)):
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
if hf_quantizer is None:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
params on a `meta` device. It replaces the model params with the data from the `state_dict`
"""
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
is_quantized = hf_quantizer is not None
empty_state_dict = model.state_dict()
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
for param_name, param in state_dict.items():
if param_name not in empty_state_dict:
@@ -214,21 +253,35 @@ def load_model_dict_into_meta(
# We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
# TODO: revisit cases when param.dtype == torch.float8_e4m3fn
if torch.is_floating_point(param):
if (
keep_in_fp32_modules is not None
and any(
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
)
and dtype == torch.float16
if dtype is not None and torch.is_floating_point(param):
if keep_in_fp32_modules is not None and any(
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
):
param = param.to(torch.float32)
if accepts_dtype:
set_module_kwargs["dtype"] = torch.float32
set_module_kwargs["dtype"] = torch.float32
else:
param = param.to(dtype)
if accepts_dtype:
set_module_kwargs["dtype"] = dtype
set_module_kwargs["dtype"] = dtype
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
if old_param.is_contiguous():
param = param.contiguous()
param_device = _determine_param_device(param_name, device_map)
# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
@@ -236,7 +289,9 @@ def load_model_dict_into_meta(
if (
is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
and hf_quantizer.check_if_quantized_param(
model, param, param_name, state_dict, param_device=param_device
)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
else:
@@ -244,35 +299,23 @@ def load_model_dict_into_meta(
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
if param_device == "disk":
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
if named_buffers is None:
return unexpected_keys
for param_name, param in named_buffers:
if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
return unexpected_keys
return offload_index, state_dict_index
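At its core, load_model_dict_into_meta materializes the parameters of a model instantiated on the meta device one tensor at a time. A stripped-down sketch of that pattern using accelerate's public helpers (a simplified illustration, not the diffusers API itself):

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    # On the meta device no parameter memory is allocated yet.
    with init_empty_weights():
        model = torch.nn.Linear(4, 4)

    # Stand-in for a real checkpoint loaded via load_state_dict(...).
    state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

    for name, tensor in state_dict.items():
        # Replaces the meta tensor with the checkpoint value, casting to the requested dtype.
        set_module_tensor_to_device(model, name, device="cpu", value=tensor, dtype=torch.float16)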
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
@@ -280,15 +323,19 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: torch.nn.Module, prefix: str = ""):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
local_metadata = {}
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(child, prefix + name + ".", assign_to_params_buffers)
load(model_to_load)
load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)
return error_msgs
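The new assign_to_params_buffers flag threads PyTorch's assign-on-load mode (available from torch 2.1, as the version check above notes) through the recursive loader: instead of copying checkpoint tensors into pre-allocated parameters, the parameters are rebound to the checkpoint tensors. A rough public-API equivalent:

    import torch

    model = torch.nn.Linear(4, 4)
    state_dict = {"weight": torch.ones(4, 4), "bias": torch.zeros(4)}

    # assign=True (torch >= 2.1) avoids the extra copy done by the default in-place load.
    model.load_state_dict(state_dict, assign=True)
    # The parameters now share storage with the checkpoint tensors.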
@@ -343,46 +390,6 @@ def _fetch_index_file(
return index_file
# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
weight_map = sharded_metadata.get("weight_map", None)
if weight_map is None:
raise KeyError("'weight_map' key not found in the shard index file.")
# Collect all unique safetensors files from weight_map
files_to_load = set(weight_map.values())
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
merged_state_dict = {}
# Load tensors from each unique file
for file_name in files_to_load:
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
if dduf_entries:
if part_file_path not in dduf_entries:
raise FileNotFoundError(f"Part file {file_name} not found.")
else:
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")
if is_safetensors:
if dduf_entries:
with dduf_entries[part_file_path].as_mmap() as mm:
tensors = safetensors.torch.load(mm)
merged_state_dict.update(tensors)
else:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
else:
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
return merged_state_dict
def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,


@@ -20,10 +20,13 @@ import itertools
import json
import os
import re
import shutil
import tempfile
from collections import OrderedDict
from contextlib import ExitStack, contextmanager
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Type, Union
import safetensors
import torch
@@ -65,16 +68,49 @@ from .model_loading_utils import (
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
_merge_sharded_checkpoints,
load_model_dict_into_meta,
load_state_dict,
)
class ContextManagers:
"""
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
in the `fastcore` library.
"""
def __init__(self, context_managers: List[ContextManager]):
self.context_managers = context_managers
self.stack = ExitStack()
def __enter__(self):
for context_manager in self.context_managers:
self.stack.enter_context(context_manager)
def __exit__(self, *args, **kwargs):
self.stack.__exit__(*args, **kwargs)
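Assuming the ContextManagers wrapper defined above is in scope, it lets from_pretrained build the list of init contexts dynamically (e.g. optionally adding init_empty_weights) and enter them all at once; an illustrative use:

    from contextlib import contextmanager

    @contextmanager
    def announce(name):
        print(f"enter {name}")
        try:
            yield
        finally:
            print(f"exit {name}")

    low_cpu_mem_usage = True
    init_contexts = [announce("no_init_weights")]
    if low_cpu_mem_usage:
        init_contexts.append(announce("init_empty_weights"))

    with ContextManagers(init_contexts):
        pass  # cls.from_config(config) would run here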
logger = logging.get_logger(__name__)
_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
TORCH_INIT_FUNCTIONS = {
"uniform_": nn.init.uniform_,
"normal_": nn.init.normal_,
"trunc_normal_": nn.init.trunc_normal_,
"constant_": nn.init.constant_,
"xavier_uniform_": nn.init.xavier_uniform_,
"xavier_normal_": nn.init.xavier_normal_,
"kaiming_uniform_": nn.init.kaiming_uniform_,
"kaiming_normal_": nn.init.kaiming_normal_,
"uniform": nn.init.uniform,
"normal": nn.init.normal,
"xavier_uniform": nn.init.xavier_uniform,
"xavier_normal": nn.init.xavier_normal,
"kaiming_uniform": nn.init.kaiming_uniform,
"kaiming_normal": nn.init.kaiming_normal,
}
if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT = True
@@ -84,6 +120,8 @@ else:
if is_accelerate_available():
import accelerate
from accelerate import dispatch_model
from accelerate.utils import load_offloaded_weights, save_offload_index
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
@@ -159,6 +197,54 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
return last_tuple[1].dtype
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
"""
Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
parameters.
"""
if model_to_load.device.type == "meta":
return False
if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
return False
# Some models explicitly do not support param buffer assignment
if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
logger.debug(
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
)
return False
# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
first_key = next(iter(model_to_load.state_dict().keys()))
if start_prefix + first_key in state_dict:
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
return False
@contextmanager
def no_init_weights():
"""
Context manager to globally disable weight initialization to speed up loading large models. To do that, all the
torch.nn.init function are all replaced with skip.
"""
def _skip_init(*args, **kwargs):
pass
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, _skip_init)
try:
yield
finally:
# Restore the original initialization functions
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, init_func)
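The point of no_init_weights is that the random initialization performed in a module's constructor is wasted work when every parameter is about to be overwritten by a checkpoint; with the init functions patched to no-ops, constructing large layers becomes much cheaper. A small illustration (assumes the context manager above is in scope):

    import torch
    from torch import nn

    with no_init_weights():
        layer = nn.Linear(1024, 1024)  # kaiming_uniform_/uniform_ are skipped here

    # The weights hold whatever torch.empty allocated, so they must be overwritten before use.
    layer.load_state_dict({"weight": torch.zeros(1024, 1024), "bias": torch.zeros(1024)})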
class ModelMixin(torch.nn.Module, PushToHubMixin):
r"""
Base class for all models.
@@ -785,7 +871,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
offload_state_dict = kwargs.pop("offload_state_dict", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
@@ -862,14 +948,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
unused_kwargs = {}
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
# load config
config, unused_kwargs, commit_hash = cls.load_config(
@@ -907,13 +994,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
hf_quantizer = None
if hf_quantizer is not None:
if device_map is not None:
raise NotImplementedError(
"Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
)
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
device_map = hf_quantizer.update_device_map(device_map)
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
@@ -926,9 +1009,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and (
hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = cls._keep_in_fp32_modules
if not isinstance(keep_in_fp32_modules, list):
@@ -941,10 +1025,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
else:
keep_in_fp32_modules = []
#######################################
is_sharded = False
resolved_model_file = None
# Determine if we're loading from a directory of sharded checkpoints.
is_sharded = False
sharded_metadata = None
index_file = None
is_local = os.path.isdir(pretrained_model_name_or_path)
index_file_kwargs = {
@@ -975,9 +1061,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
# load model
model_file = None
if from_flax:
model_file = _get_model_file(
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=FLAX_WEIGHTS_NAME,
cache_dir=cache_dir,
@@ -995,11 +1080,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Convert the weights
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
else:
# in the case it is sharded, we have already the index
if is_sharded:
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
cache_dir=cache_dir,
@@ -1011,17 +1096,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
subfolder=subfolder or "",
dduf_entries=dduf_entries,
)
# TODO: https://github.com/huggingface/diffusers/issues/10013
if hf_quantizer is not None or dduf_entries:
model_file = _merge_sharded_checkpoints(
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries
)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False
elif use_safetensors and not is_sharded:
elif use_safetensors:
try:
model_file = _get_model_file(
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
@@ -1044,8 +1121,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)
if model_file is None and not is_sharded:
model_file = _get_model_file(
if resolved_model_file is None and not is_sharded:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
@@ -1060,157 +1137,104 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries=dduf_entries,
)
if low_cpu_mem_usage:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model = cls.from_config(config, **unused_kwargs)
if not isinstance(resolved_model_file, list):
resolved_model_file = [resolved_model_file]
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
)
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None and not is_sharded:
# `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
# It would error out during the `validate_environment()` call above in the absence of cuda.
if hf_quantizer is None:
param_device = "cpu"
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
else:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(
model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
named_buffers = model.named_buffers()
unexpected_keys = load_model_dict_into_meta(
model,
state_dict,
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_name_or_path,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
named_buffers=named_buffers,
)
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
device_map = _determine_device_map(
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
)
if device_map is None and is_sharded:
# we load the parameters on the cpu
device_map = {"": "cpu"}
try:
accelerate.load_checkpoint_and_dispatch(
model,
model_file if not is_sharded else index_file,
device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
strict=True,
)
except AttributeError as e:
# When using accelerate loading, we do not have the ability to load the state
# dict and rename the weight names manually. Additionally, accelerate skips
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
# (which look like they should be private variables?), so we can't use the standard hooks
# to rename parameters on load. We need to mimic the original weight names so the correct
# attributes are available. After we have loaded the weights, we convert the deprecated
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
# the weights so we don't have to do this again.
if "'Attention' object has no attribute" in str(e):
logger.warning(
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
" please also re-upload it or open a PR on the original repository."
)
model._temp_convert_self_to_deprecated_attention_blocks()
accelerate.load_checkpoint_and_dispatch(
model,
model_file if not is_sharded else index_file,
device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
strict=True,
)
model._undo_temp_convert_self_to_deprecated_attention_blocks()
else:
raise e
loading_info = {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
else:
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(
model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is float8, we don't use _set_default_torch_dtype and we downcast after loading the model
dtype_orig = None
if torch_dtype is not None and not torch_dtype == getattr(torch, "float8_e4m3fn", None):
if not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
model._convert_deprecated_attention_blocks(state_dict)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
model_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)
init_contexts = [no_init_weights()]
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
if low_cpu_mem_usage:
init_contexts.append(accelerate.init_empty_weights())
with ContextManagers(init_contexts):
model = cls.from_config(config, **unused_kwargs)
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
state_dict = None
if not is_sharded:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
# We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
model._fix_state_dict_keys_on_load(state_dict)
if is_sharded:
loaded_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_keys = list(state_dict.keys())
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
)
print(keep_in_fp32_modules)
# Now that the model is loaded, we can determine the device_map
device_map = _determine_device_map(
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
)
if hf_quantizer is not None:
hf_quantizer.validate_environment(device_map=device_map)
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
resolved_model_file,
pretrained_model_name_or_path,
loaded_keys,
ignore_mismatched_sizes=ignore_mismatched_sizes,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
)
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
# Dispatch model with hooks on all devices if necessary
if device_map is not None:
device_map_kwargs = {
"device_map": device_map,
"offload_dir": offload_folder,
"offload_index": offload_index,
}
dispatch_model(model, **device_map_kwargs)
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
model.hf_quantizer = hf_quantizer
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
# When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
# completely lose the effectivity of `use_keep_in_fp32_modules`.
elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
if (
torch_dtype is not None
and torch_dtype == getattr(torch, "float8_e4m3fn", None)
and hf_quantizer is None
and not use_keep_in_fp32_modules
):
model = model.to(torch_dtype)
if hf_quantizer is not None:
@@ -1222,6 +1246,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if output_loading_info:
return model, loading_info
@@ -1332,54 +1357,127 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
cls,
model,
state_dict: OrderedDict,
resolved_archive_file,
resolved_model_file: List[str],
pretrained_model_name_or_path: Union[str, os.PathLike],
loaded_keys: List[str],
ignore_mismatched_sizes: bool = False,
assign_to_params_buffers: bool = False,
hf_quantizer: Optional[DiffusersQuantizer] = None,
low_cpu_mem_usage: bool = True,
dtype: Optional[Union[str, torch.dtype]] = None,
keep_in_fp32_modules: Optional[List[str]] = None,
device_map: Dict[str, Union[int, str, torch.device]] = None,
offload_state_dict: Optional[bool] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
loaded_keys = list(state_dict.keys())
expected_keys = list(model_state_dict.keys())
original_loaded_keys = loaded_keys
missing_keys = list(set(expected_keys) - set(loaded_keys))
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
# Make sure we are able to load base models as well as derived models (with heads)
model_to_load = model
mismatched_keys = []
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
assign_to_params_buffers = None
error_msgs = []
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if offload_folder is not None:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
# load_state_dict will manage the case where we pass a dict instead of a file
# if state dict is not None, it means that we don't need to read the files from resolved_model_file also
resolved_model_file = [state_dict]
if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
def _find_mismatched_keys(
state_dict,
model_state_dict,
original_loaded_keys,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
if low_cpu_mem_usage:
offload_index, state_dict_index = load_model_dict_into_meta(
model,
state_dict,
device_map=device_map,
dtype=dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_index=state_dict_index,
state_dict_folder=state_dict_folder,
)
else:
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
offload_index = None
if offload_state_dict:
load_offloaded_weights(model, state_dict_index, state_dict_folder)
shutil.rmtree(state_dict_folder)
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@@ -1391,17 +1489,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
" identical (initializing a BertForSequenceClassification model from a"
" BertForSequenceClassification model)."
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
@@ -1429,7 +1521,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" able to use it for predictions and inference."
)
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
@classmethod
def _get_signature_keys(cls, obj):
@@ -1470,6 +1562,33 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
modules_to_check += list(module.children())
return list(_no_split_modules)
@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
under specific dtype.
Args:
dtype (`torch.dtype`):
a floating dtype to set to.
Returns:
`torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
modified. If it wasn't, returns `None`.
Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
`torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
"""
if not dtype.is_floating_point:
raise ValueError(
f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
)
logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(dtype)
return dtype_orig
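Changing the default dtype before instantiation means parameters are created directly in the target precision rather than being created in fp32 and cast afterwards; in plain torch terms:

    import torch

    dtype_orig = torch.get_default_dtype()   # typically torch.float32
    torch.set_default_dtype(torch.float16)
    layer = torch.nn.Linear(4, 4)             # weight/bias are allocated in fp16 directly
    torch.set_default_dtype(dtype_orig)       # restore, as from_pretrained does after from_config
    assert layer.weight.dtype == torch.float16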
@property
def device(self) -> torch.device:
"""
@@ -1585,7 +1704,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
)
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
"""
This function fix the state dict of the model to take into account some changes that were made in the model
architecture:
- deprecated attention blocks (happened before we introduced sharded checkpoint,
so this is why we apply this method only when loading non sharded checkpoints for now)
"""
deprecated_attention_block_paths = []
def recursive_find_attn_block(name, module):
@@ -1628,56 +1753,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
if f"{path}.proj_attn.bias" in state_dict:
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
deprecated_attention_block_modules = []
def recursive_find_attn_block(module):
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
deprecated_attention_block_modules.append(module)
for sub_module in module.children():
recursive_find_attn_block(sub_module)
recursive_find_attn_block(self)
for module in deprecated_attention_block_modules:
module.query = module.to_q
module.key = module.to_k
module.value = module.to_v
module.proj_attn = module.to_out[0]
# We don't _have_ to delete the old attributes, but it's helpful to ensure
# that _all_ the weights are loaded into the new attributes and we're not
# making an incorrect assumption that this model should be converted when
# it really shouldn't be.
del module.to_q
del module.to_k
del module.to_v
del module.to_out
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
deprecated_attention_block_modules = []
def recursive_find_attn_block(module) -> None:
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
deprecated_attention_block_modules.append(module)
for sub_module in module.children():
recursive_find_attn_block(sub_module)
recursive_find_attn_block(self)
for module in deprecated_attention_block_modules:
module.to_q = module.query
module.to_k = module.key
module.to_v = module.value
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
del module.query
del module.key
del module.value
del module.proj_attn
return state_dict
class LegacyModelMixin(ModelMixin):


@@ -280,9 +280,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
act_fn="silu_fp32",
)
self.text_embedding_padding = nn.Parameter(
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
)
self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))
self.pos_embed = PatchEmbed(
height=sample_size,


@@ -693,7 +693,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
offload_state_dict = kwargs.pop("offload_state_dict", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
dduf_file = kwargs.pop("dduf_file", None)


@@ -235,18 +235,16 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
torch_dtype = torch.float16
return torch_dtype
# (sayakpaul): I think it could be better to disable custom `device_map`s
# for the first phase of the integration in the interest of simplicity.
# Commenting this for discussions on the PR.
# def update_device_map(self, device_map):
# if device_map is None:
# device_map = {"": torch.cuda.current_device()}
# logger.info(
# "The device_map was not initialized. "
# "Setting device_map to {'':torch.cuda.current_device()}. "
# "If you want to use the model for inference, please set device_map ='auto' "
# )
# return device_map
def update_device_map(self, device_map):
if device_map is None:
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
logger.info(
"The device_map was not initialized. "
"Setting device_map to {"
": f`cuda:{torch.cuda.current_device()}`}. "
"If you want to use the model for inference, please set device_map ='auto' "
)
return device_map
def _process_model_before_weight_loading(
self,
@@ -289,9 +287,9 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
model.is_loaded_in_4bit = True
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
model.is_loaded_in_4bit = True
model.is_4bit_serializable = self.is_serializable
return model
@@ -400,16 +398,17 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
torch_dtype = torch.float16
return torch_dtype
# # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
# def update_device_map(self, device_map):
# if device_map is None:
# device_map = {"": torch.cuda.current_device()}
# logger.info(
# "The device_map was not initialized. "
# "Setting device_map to {'':torch.cuda.current_device()}. "
# "If you want to use the model for inference, please set device_map ='auto' "
# )
# return device_map
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
def update_device_map(self, device_map):
if device_map is None:
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
logger.info(
"The device_map was not initialized. "
"Setting device_map to {"
": f`cuda:{torch.cuda.current_device()}`}. "
"If you want to use the model for inference, please set device_map ='auto' "
)
return device_map
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if target_dtype != torch.int8:
@@ -493,11 +492,10 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
model.is_loaded_in_8bit = True
model.is_8bit_serializable = self.is_serializable
return model
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
@@ -539,6 +537,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
model.is_loaded_in_8bit = True
@property
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable

View File

@@ -338,22 +338,6 @@ def _get_model_file(
) from e
# Adapted from
# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976
# Differences are in parallelization of shard downloads and checking if shards are present.
def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
shards_path = os.path.join(local_dir, subfolder)
shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
for shard_file in shard_filenames:
if not os.path.exists(shard_file):
raise ValueError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
def _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_filename,
@@ -396,13 +380,22 @@ def _get_checkpoint_shard_files(
shards_path = os.path.join(pretrained_model_name_or_path, subfolder)
# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
_check_if_shards_exist_locally(
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
)
return shards_path, sharded_metadata
elif dduf_entries:
return shards_path, sharded_metadata
if os.path.isdir(pretrained_model_name_or_path) or dduf_entries:
shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
for shard_file in shard_filenames:
if dduf_entries:
if shard_file not in dduf_entries:
raise FileNotFoundError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
else:
if not os.path.exists(shard_file):
raise FileNotFoundError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
return shard_filenames, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns = original_shard_filenames
@@ -444,7 +437,9 @@ def _get_checkpoint_shard_files(
" again after checking your internet connection."
) from e
return cached_folder, sharded_metadata
cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
return cached_filenames, sharded_metadata
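The refactor makes _get_checkpoint_shard_files return the list of shard paths rather than the containing folder; for a local directory, the resolution and existence check boil down to the following sketch (hypothetical folder layout and index contents):

    import os, tempfile

    # Hypothetical sharded checkpoint folder with an index mapping each tensor to its shard.
    folder = tempfile.mkdtemp()
    index = {"weight_map": {"proj_in.weight": "model-00001-of-00002.safetensors",
                            "proj_out.weight": "model-00002-of-00002.safetensors"}}
    for name in set(index["weight_map"].values()):
        open(os.path.join(folder, name), "wb").close()   # placeholder shard files

    shard_filenames = [os.path.join(folder, f) for f in sorted(set(index["weight_map"].values()))]
    missing = [f for f in shard_filenames if not os.path.exists(f)]
    if missing:
        raise FileNotFoundError(f"{folder} is missing shards required by the index: {missing}")
    # shard_filenames is what the helper now returns instead of the containing folder.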
def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):

View File

@@ -37,7 +37,7 @@ from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
from requests.exceptions import HTTPError
from diffusers.models import UNet2DConditionModel
from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
@@ -200,12 +200,12 @@ class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context:
def test_missing_key_loading_warning_message(self):
with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
# make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception)
assert "conv_out.bias" in " ".join(logs.output)
@parameterized.expand(
[
@@ -334,6 +334,58 @@ class ModelUtilsTest(unittest.TestCase):
assert model.config.in_channels == 9
@require_torch_gpu
def test_keep_modules_in_fp32(self):
r"""
A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
Also ensures if inference works.
"""
fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
for torch_dtype in [torch.bfloat16, torch.float16]:
SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
model = SD3Transformer2DModel.from_pretrained(
"hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch_dtype
).to(torch_device)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if name in model._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.float32)
else:
self.assertTrue(module.weight.dtype == torch_dtype)
def get_dummy_inputs():
batch_size = 2
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
# test if inference works.
with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch_dtype):
input_dict_for_transformer = get_dummy_inputs()
model_inputs = {
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
}
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
_ = model(**model_inputs)
SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
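Per the commit message item "cosine sim instead", the quantization tests further below compare output slices by cosine-similarity distance rather than element-wise closeness; the check amounts to something like this (a sketch of the idea, not diffusers' exact helper):

    import numpy as np

    def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
        a, b = a.flatten(), b.flatten()
        # 0.0 means the two output slices point in exactly the same direction.
        return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    assert cosine_similarity_distance(np.ones(9), 2.0 * np.ones(9)) < 1e-3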
class UNetTesterMixin:
def test_forward_with_norm_groups(self):


@@ -136,7 +136,7 @@ class BnB4BitBasicTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
self.model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
def tearDown(self):
@@ -202,7 +202,7 @@ class BnB4BitBasicTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
model = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
for name, module in model.named_modules():
@@ -327,7 +327,7 @@ class BnB4BitBasicTests(Base4bitTests):
with tempfile.TemporaryDirectory() as tmpdirname:
nf4_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
model_4bit.save_pretrained(tmpdirname)
del model_4bit
@@ -362,7 +362,7 @@ class BnB4BitTrainingTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
self.model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
def test_training(self):
@@ -410,7 +410,7 @@ class SlowBnb4BitTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
self.pipeline_4bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_4bit, torch_dtype=torch.float16
@@ -472,7 +472,7 @@ class SlowBnb4BitTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
@@ -502,6 +502,7 @@ class SlowBnb4BitTests(Base4bitTests):
subfolder="transformer",
quantization_config=transformer_nf4_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
text_encoder_3_nf4_config = BnbConfig(
load_in_4bit=True,
@@ -513,6 +514,7 @@ class SlowBnb4BitTests(Base4bitTests):
subfolder="text_encoder_3",
quantization_config=text_encoder_3_nf4_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
# CUDA device placement works.
pipeline_4bit = DiffusionPipeline.from_pretrained(
@@ -527,6 +529,94 @@ class SlowBnb4BitTests(Base4bitTests):
del pipeline_4bit
def test_device_map(self):
"""
Test that the quantized model works properly with device_map="auto".
Note that cpu/disk offloading does not work with bnb.
"""
def get_dummy_tensor_inputs(device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}
inputs = get_dummy_tensor_inputs(torch_device)
expected_slice = np.array(
[0.47070312, 0.00390625, -0.03662109, -0.19628906, -0.53125, 0.5234375, -0.17089844, -0.59375, 0.578125]
)
# non sharded
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
# sharded
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
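Outside the test harness, the same loading path looks roughly like the sketch below. This is an illustration only, assuming bitsandbytes, a CUDA device, and the tiny Flux checkpoint referenced above.

import torch

from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# NF4 4-bit quantization; device_map="auto" dispatches the quantized weights onto the available GPU.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
quantized_model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Quantized linear weights are stored as bitsandbytes Params4bit tensors.
print(type(quantized_model.transformer_blocks[0].ff.net[2].weight))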
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
@@ -610,7 +700,10 @@ class BaseBnb4BitSerializationTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.bfloat16,
)
model_0 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=self.quantization_config
self.model_name,
subfolder="transformer",
quantization_config=self.quantization_config,
device_map=torch_device,
)
self.assertTrue("_pre_quantization_dtype" in model_0.config)
with tempfile.TemporaryDirectory() as tmpdirname:

View File

@@ -138,7 +138,7 @@ class BnB8bitBasicTests(Base8bitTests):
)
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
def tearDown(self):
@@ -200,7 +200,7 @@ class BnB8bitBasicTests(Base8bitTests):
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
for name, module in model.named_modules():
@@ -242,7 +242,7 @@ class BnB8bitBasicTests(Base8bitTests):
"""
config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"])
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=config
self.model_name, subfolder="transformer", quantization_config=config, device_map=torch_device
)
linear = get_some_linear_layer(model_8bit)
self.assertTrue(linear.weight.dtype == torch.int8)
@@ -319,6 +319,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
subfolder="transformer",
quantization_config=mixed_int8_config,
device_map=torch_device,
)
def tearDown(self):
@@ -343,7 +344,7 @@ class BnB8bitTrainingTests(Base8bitTests):
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
def test_training(self):
@@ -387,7 +388,7 @@ class SlowBnb8bitTests(Base8bitTests):
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
self.pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_8bit, torch_dtype=torch.float16
@@ -415,7 +416,10 @@ class SlowBnb8bitTests(Base8bitTests):
def test_model_cpu_offload_raises_warning(self):
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
self.model_name,
subfolder="transformer",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map=torch_device,
)
pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_8bit, torch_dtype=torch.float16
@@ -430,7 +434,10 @@ class SlowBnb8bitTests(Base8bitTests):
def test_moving_to_cpu_throws_warning(self):
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
self.model_name,
subfolder="transformer",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map=torch_device,
)
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(30)
@@ -483,6 +490,7 @@ class SlowBnb8bitTests(Base8bitTests):
subfolder="transformer",
quantization_config=transformer_8bit_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
text_encoder_3_8bit = T5EncoderModel.from_pretrained(
@@ -490,6 +498,7 @@ class SlowBnb8bitTests(Base8bitTests):
subfolder="text_encoder_3",
quantization_config=text_encoder_3_8bit_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
# CUDA device placement works.
pipeline_8bit = DiffusionPipeline.from_pretrained(
@@ -504,6 +513,99 @@ class SlowBnb8bitTests(Base8bitTests):
del pipeline_8bit
def test_device_map(self):
"""
Test that the quantized model works properly with device_map="auto".
Note that cpu/disk offloading does not work with bnb.
"""
def get_dummy_tensor_inputs(device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}
inputs = get_dummy_tensor_inputs(torch_device)
expected_slice = np.array(
[
0.33789062,
-0.04736328,
-0.00256348,
-0.23144531,
-0.49804688,
0.4375,
-0.15429688,
-0.65234375,
0.44335938,
]
)
# non sharded
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
# sharded
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
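The 8-bit tests above also pass an explicit device_map to every quantized from_pretrained call. A minimal sketch of that pattern, assuming a CUDA device and the tiny SD3 checkpoint used elsewhere in these tests:

import torch

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Passing the target device as device_map places the 8-bit weights there at load time.
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-sd3-pipe",
    subfolder="transformer",
    quantization_config=mixed_int8_config,
    device_map="cuda",
)
for name, module in model_8bit.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name, module.weight.dtype)  # torch.int8 for quantized linear layers
        break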
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
@@ -579,7 +681,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
load_in_8bit=True,
)
self.model_0 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=quantization_config
self.model_name, subfolder="transformer", quantization_config=quantization_config, device_map=torch_device
)
def tearDown(self):

View File

@@ -34,6 +34,7 @@ from diffusers.utils.testing_utils import (
is_torch_available,
is_torchao_available,
nightly,
numpy_cosine_similarity_distance,
require_torch,
require_torch_gpu,
require_torchao_version_greater_or_equal,
@@ -282,9 +283,6 @@ class TorchAoTest(unittest.TestCase):
self.assertEqual(weight.quant_max, 15)
def test_device_map(self):
# Note: We were not checking if the weight tensors were AffineQuantizedTensors before. If we did,
# it would have errored out. Now, we do. So, device_map basically never worked with or without
# sharded checkpoints. This will need to be supported in the future (TODO(aryan))
"""
Test that the int4 weight-only quantized model works properly with "auto" and custom device maps.
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
@@ -301,54 +299,73 @@ class TorchAoTest(unittest.TestCase):
}
device_maps = ["auto", custom_device_map_dict]
# inputs = self.get_dummy_tensor_inputs(torch_device)
# expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
inputs = self.get_dummy_tensor_inputs(torch_device)
# requires different expected slices since the models differ when offloading (modules offloaded to cpu/disk are not quantized)
expected_slice_auto = np.array(
[
0.34179688,
-0.03613281,
0.01428223,
-0.22949219,
-0.49609375,
0.4375,
-0.1640625,
-0.66015625,
0.43164062,
]
)
expected_slice_offload = np.array(
[0.34375, -0.03515625, 0.0123291, -0.22753906, -0.49414062, 0.4375, -0.16308594, -0.66015625, 0.43554688]
)
for device_map in device_maps:
# device_map_to_compare = {"": 0} if device_map == "auto" else device_map
if device_map == "auto":
expected_slice = expected_slice_auto
else:
expected_slice = expected_slice_offload
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
# Test non-sharded model - should work
with self.assertRaises(NotImplementedError):
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
# Note that when performing cpu/disk offload, the offloaded weights are not quantized, only the weights on the gpu.
# This is not the case when the models are already quantized.
if "transformer_blocks.0" in device_map:
self.assertTrue(isinstance(weight, nn.Parameter))
else:
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
# Test sharded model - should not work
with self.assertRaises(NotImplementedError):
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
if "transformer_blocks.0" in device_map:
self.assertTrue(isinstance(weight, nn.Parameter))
else:
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
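In other words, after this refactor the torchao path loads with a device map instead of raising. A rough sketch under the same assumptions as the test (torchao installed, a CUDA device, and the tiny Flux checkpoints); with a custom map that offloads modules to cpu/disk, the offloaded weights stay plain, unquantized parameters, while GPU-resident weights become AffineQuantizedTensors.

import tempfile

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int4_weight_only", group_size=64)

# Both the plain and the sharded tiny Flux checkpoints now load with a device map.
with tempfile.TemporaryDirectory() as offload_folder:
    quantized_model = FluxTransformer2DModel.from_pretrained(
        "hf-internal-testing/tiny-flux-sharded",
        subfolder="transformer",
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        offload_folder=offload_folder,
    )
    # With "auto" everything fits on the GPU, so this weight is an AffineQuantizedTensor.
    print(type(quantized_model.transformer_blocks[0].ff.net[2].weight))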
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
@@ -544,7 +561,7 @@ class TorchAoSerializationTest(unittest.TestCase):
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
@@ -564,7 +581,7 @@ class TorchAoSerializationTest(unittest.TestCase):
loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
)
)
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
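To make the change above concrete, here is an illustrative comparison with made-up slice values, assuming numpy_cosine_similarity_distance from diffusers.utils.testing_utils (imported earlier in this diff): the serialization checks now compare output slices by cosine-similarity distance rather than element-wise allclose, which is less brittle across hardware and kernels.

import numpy as np

from diffusers.utils.testing_utils import numpy_cosine_similarity_distance

# Illustrative values only: slices that differ slightly element-wise but point in the same direction.
output_slice = np.array([0.34, -0.04, 0.01, -0.23, -0.50, 0.44, -0.16, -0.66, 0.43])
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
assert numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3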
def test_int_a8w8_cuda(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}