Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[FEAT] Model loading refactor (#10604)
* first draft model loading refactor
* revert name change
* fix bnb
* revert name
* fix dduf
* fix huanyan
* style
* Update src/diffusers/models/model_loading_utils.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* suggestions from reviews
* Update src/diffusers/models/modeling_utils.py
  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* remove safetensors check
* fix default value
* more fix from suggestions
* revert logic for single file
* style
* typing + fix couple of issues
* improve speed
* Update src/diffusers/models/modeling_utils.py
  Co-authored-by: Aryan <aryan@huggingface.co>
* fp8 dtype
* add tests
* rename resolved_archive_file to resolved_model_file
* format
* map_location default cpu
* add utility function
* switch to smaller model + test inference
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* rm comment
* add log
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* add decorator
* cosine sim instead
* fix use_keep_in_fp32_modules
* comm

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
@@ -52,7 +52,7 @@ logger = logging.get_logger(__name__)

if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate import dispatch_model, init_empty_weights

from ..models.modeling_utils import load_model_dict_into_meta

@@ -366,19 +366,23 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
)

device_map = None
if is_accelerate_available():
param_device = torch.device(device) if device else torch.device("cpu")
named_buffers = model.named_buffers()
unexpected_keys = load_model_dict_into_meta(
empty_state_dict = model.state_dict()
unexpected_keys = [
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
]
device_map = {"": param_device}
load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device=param_device,
device_map=device_map,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
named_buffers=named_buffers,
unexpected_keys=unexpected_keys,
)

else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

@@ -400,4 +404,8 @@ class FromOriginalModelMixin:

model.eval()

if device_map is not None:
device_map_kwargs = {"device_map": device_map}
dispatch_model(model, **device_map_kwargs)

return model
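The hunks above replace the bare `device` argument in `FromOriginalModelMixin.from_single_file` with a whole-model `device_map` plus an explicit `dispatch_model` call. A minimal sketch of that flow, assuming `accelerate` is installed and using a toy `torch.nn.Linear` in place of a real diffusers model; the keyword names follow this hunk and may differ in other versions:

```python
# Illustrative sketch only: mirrors the new single-file loading flow shown above.
# The toy nn.Linear and the dense state dict are assumptions, not the commit's code.
import torch
from accelerate import dispatch_model, init_empty_weights

from diffusers.models.modeling_utils import load_model_dict_into_meta

with init_empty_weights():
    model = torch.nn.Linear(4, 4)  # parameters live on the "meta" device, no memory allocated

# A checkpoint produced elsewhere; here just a dense CPU state dict.
state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# Keys present in the checkpoint but missing from the model are collected up front.
unexpected_keys = [k for k in state_dict if k not in model.state_dict()]

# A trivial device_map ({"": device}) replaces the old bare `device` argument.
device_map = {"": torch.device("cpu")}
load_model_dict_into_meta(
    model,
    state_dict,
    dtype=torch.float32,
    device_map=device_map,
    unexpected_keys=unexpected_keys,
)

# The weights are then placed/hooked according to the same map via accelerate.
dispatch_model(model, device_map=device_map)
```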
@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
model.load_state_dict(diffusers_format_checkpoint, strict=False)

if torch_dtype is not None:
model.to(torch_dtype)

@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)

load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
model.load_state_dict(diffusers_format_checkpoint)
@@ -20,13 +20,15 @@ import os
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile

import safetensors
import torch
from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError

from ..quantizers import DiffusersQuantizer
from ..utils import (
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,

@@ -55,7 +57,7 @@ _CLASS_REMAPPING_DICT = {

if is_accelerate_available():
from accelerate import infer_auto_device_map
from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device


# Adapted from `transformers` (see modeling_utils.py)

@@ -132,17 +134,46 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class


def _check_archive_and_maybe_raise_error(checkpoint_file, format_list):
"""
Check format of the archive
"""
with safetensors.safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in format_list:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)


def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
"""
Find the device of param_name from the device_map.
"""
if device_map is None:
return "cpu"
else:
module_name = param_name
# find next higher level module that is defined in device_map:
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
if module_name == "" and "" not in device_map:
raise ValueError(f"{param_name} doesn't have any device set.")
return device_map[module_name]
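The new `_determine_param_device` helper resolves a parameter's device by walking up the module hierarchy until it hits an entry in the `device_map`. A small self-contained sketch that reimplements the same lookup (rather than importing the private helper) shows how nested names resolve:

```python
# Self-contained sketch of the lookup logic introduced above; not an import of the private helper.
def determine_param_device(param_name, device_map):
    if device_map is None:
        return "cpu"
    module_name = param_name
    # Walk upwards: "bert.lm_head.weight" -> "bert.lm_head" -> "bert" -> ""
    while len(module_name) > 0 and module_name not in device_map:
        module_name = ".".join(module_name.split(".")[:-1])
    if module_name == "" and "" not in device_map:
        raise ValueError(f"{param_name} doesn't have any device set.")
    return device_map[module_name]

device_map = {"encoder": 0, "decoder": "cpu", "lm_head": "disk"}
print(determine_param_device("encoder.layers.0.attn.weight", device_map))  # 0
print(determine_param_device("lm_head.weight", device_map))                # "disk"
print(determine_param_device("decoder.embed.weight", None))                # "cpu"
```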
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
variant: Optional[str] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
disable_mmap: bool = False,
map_location: Union[str, torch.device] = "cpu",
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
# TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
# when refactoring the _merge_sharded_checkpoints() method later.
# TODO: maybe refactor a bit this part where we pass a dict here
if isinstance(checkpoint_file, dict):
return checkpoint_file
try:

@@ -152,19 +183,26 @@ def load_state_dict(
# tensors are loaded on cpu
with dduf_entries[checkpoint_file].as_mmap() as mm:
return safetensors.torch.load(mm)
_check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"])
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
return safetensors.torch.load_file(checkpoint_file, device=map_location)
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
extra_args = {}
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
checkpoint_file,
map_location="cpu",
**weights_only_kwarg,
)
# mmap can only be used with files serialized with zipfile-based format.
if (
isinstance(checkpoint_file, str)
and map_location != "meta"
and is_torch_version(">=", "2.1.0")
and is_zipfile(checkpoint_file)
and not disable_mmap
):
extra_args = {"mmap": True}
return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
except Exception as e:
try:
with open(checkpoint_file) as f:
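The branch added above only enables `mmap=True` when the checkpoint is a zipfile-serialized `torch.save` artifact, torch is at least 2.1, the target is not the meta device, and mmap has not been disabled. A standalone sketch of the same decision using only public `torch` APIs; the file name is a placeholder:

```python
# Standalone sketch of the mmap decision above; "checkpoint.bin" is a placeholder path.
import torch
from zipfile import is_zipfile


def load_torch_checkpoint(checkpoint_file, map_location="cpu", disable_mmap=False):
    extra_args = {}
    torch_ge_21 = tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (2, 1)
    # mmap can only be used with files serialized in the zipfile-based format,
    # and torch only supports the flag from 2.1 onwards.
    if (
        isinstance(checkpoint_file, str)
        and map_location != "meta"
        and torch_ge_21
        and is_zipfile(checkpoint_file)
        and not disable_mmap
    ):
        extra_args = {"mmap": True}
    return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)


# state_dict = load_torch_checkpoint("checkpoint.bin", map_location="cpu")
```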
@@ -188,23 +226,24 @@ def load_state_dict(

def load_model_dict_into_meta(
model,
state_dict: OrderedDict,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
hf_quantizer=None,
keep_in_fp32_modules=None,
named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
hf_quantizer: Optional[DiffusersQuantizer] = None,
keep_in_fp32_modules: Optional[List] = None,
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
unexpected_keys: Optional[List[str]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
offload_index: Optional[Dict] = None,
state_dict_index: Optional[Dict] = None,
state_dict_folder: Optional[Union[str, os.PathLike]] = None,
) -> List[str]:
if device is not None and not isinstance(device, (str, torch.device)):
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
if hf_quantizer is None:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
params on a `meta` device. It replaces the model params with the data from the `state_dict`
"""

accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
is_quantized = hf_quantizer is not None
empty_state_dict = model.state_dict()
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]

for param_name, param in state_dict.items():
if param_name not in empty_state_dict:

@@ -214,21 +253,35 @@ def load_model_dict_into_meta(
# We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
# TODO: revisit cases when param.dtype == torch.float8_e4m3fn
if torch.is_floating_point(param):
if (
keep_in_fp32_modules is not None
and any(
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
)
and dtype == torch.float16
if dtype is not None and torch.is_floating_point(param):
if keep_in_fp32_modules is not None and any(
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
):
param = param.to(torch.float32)
if accepts_dtype:
set_module_kwargs["dtype"] = torch.float32
set_module_kwargs["dtype"] = torch.float32
else:
param = param.to(dtype)
if accepts_dtype:
set_module_kwargs["dtype"] = dtype
set_module_kwargs["dtype"] = dtype

# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)

if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None

if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)

if old_param.is_contiguous():
param = param.contiguous()

param_device = _determine_param_device(param_name, device_map)

# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied

@@ -236,7 +289,9 @@ def load_model_dict_into_meta(
if (
is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
and hf_quantizer.check_if_quantized_param(
model, param, param_name, state_dict, param_device=param_device
)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
else:

@@ -244,35 +299,23 @@ def load_model_dict_into_meta(
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)

if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
if param_device == "disk":
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

if named_buffers is None:
return unexpected_keys

for param_name, param in named_buffers:
if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)

return unexpected_keys
return offload_index, state_dict_index
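With this change, `load_model_dict_into_meta` routes a parameter whose resolved device is `"disk"` (or `"cpu"` when a temporary state-dict folder is in use) through accelerate's offloading utilities and returns the updated `(offload_index, state_dict_index)` pair instead of the old list of unexpected keys. A minimal sketch of what those utilities do on their own, independent of the diffusers plumbing; the folder and tensor name are placeholders:

```python
# Minimal sketch of the accelerate offloading utilities used above; the folder and
# tensor names are placeholders, and the surrounding diffusers loop is omitted.
import tempfile

import torch
from accelerate.utils import offload_weight, save_offload_index

offload_folder = tempfile.mkdtemp()
offload_index = {}

# Each call memory-maps the tensor to disk and records its dtype/shape in the index.
weight = torch.randn(8, 8)
offload_index = offload_weight(weight, "block.0.attn.to_q.weight", offload_folder, offload_index)

# Persist the index so dispatch_model / the offload hooks can stream weights back later.
save_offload_index(offload_index, offload_folder)
print(sorted(offload_index))  # ['block.0.attn.to_q.weight']
```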
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()

@@ -280,15 +323,19 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[

# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: torch.nn.Module, prefix: str = ""):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
local_metadata = {}
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
module._load_from_state_dict(*args)

for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(child, prefix + name + ".", assign_to_params_buffers)

load(model_to_load)
load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)

return error_msgs
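The new `assign_to_params_buffers` flag, passed through `local_metadata`, taps into the `assign` mechanism PyTorch added in 2.1: instead of copying checkpoint tensors into pre-allocated parameters, the module's parameters are rebound to the checkpoint tensors, skipping one full copy. A standalone sketch of the same distinction using the public `load_state_dict(assign=...)` API, assuming torch >= 2.1:

```python
# Standalone sketch of the assign-vs-copy distinction exploited above (torch >= 2.1 assumed).
import torch

model = torch.nn.Linear(4, 4)
checkpoint = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# Default behaviour: values are copied into the existing parameter storage.
model.load_state_dict(checkpoint)
print(model.weight.data_ptr() == checkpoint["weight"].data_ptr())  # False

# assign=True rebinds the parameters to the checkpoint tensors (no extra copy).
model.load_state_dict(checkpoint, assign=True)
print(model.weight.data_ptr() == checkpoint["weight"].data_ptr())  # True
```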
@@ -343,46 +390,6 @@ def _fetch_index_file(
return index_file


# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
weight_map = sharded_metadata.get("weight_map", None)
if weight_map is None:
raise KeyError("'weight_map' key not found in the shard index file.")

# Collect all unique safetensors files from weight_map
files_to_load = set(weight_map.values())
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
merged_state_dict = {}

# Load tensors from each unique file
for file_name in files_to_load:
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
if dduf_entries:
if part_file_path not in dduf_entries:
raise FileNotFoundError(f"Part file {file_name} not found.")
else:
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")

if is_safetensors:
if dduf_entries:
with dduf_entries[part_file_path].as_mmap() as mm:
tensors = safetensors.torch.load(mm)
merged_state_dict.update(tensors)
else:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
else:
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))

return merged_state_dict


def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,
@@ -20,10 +20,13 @@ import itertools
import json
import os
import re
import shutil
import tempfile
from collections import OrderedDict
from contextlib import ExitStack, contextmanager
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Type, Union

import safetensors
import torch

@@ -65,16 +68,49 @@ from .model_loading_utils import (
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
_merge_sharded_checkpoints,
load_model_dict_into_meta,
load_state_dict,
)


class ContextManagers:
"""
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
in the `fastcore` library.
"""

def __init__(self, context_managers: List[ContextManager]):
self.context_managers = context_managers
self.stack = ExitStack()

def __enter__(self):
for context_manager in self.context_managers:
self.stack.enter_context(context_manager)

def __exit__(self, *args, **kwargs):
self.stack.__exit__(*args, **kwargs)
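The new `ContextManagers` wrapper simply enters every context manager it is given on a single `ExitStack`; this is what lets `from_pretrained` later in this diff combine `no_init_weights()` with `accelerate.init_empty_weights()` in one `with` block. A small self-contained sketch of the same pattern, using printing contexts in place of the real ones:

```python
# Self-contained sketch of the ContextManagers pattern introduced above.
from contextlib import ExitStack, contextmanager


class ContextManagers:
    def __init__(self, context_managers):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for cm in self.context_managers:
            self.stack.enter_context(cm)

    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)


@contextmanager
def announce(name):
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")


# Both contexts are active inside the single `with` block, mirroring how the init path
# stacks no_init_weights() with accelerate.init_empty_weights() when low_cpu_mem_usage=True.
with ContextManagers([announce("no_init_weights"), announce("init_empty_weights")]):
    pass
```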
logger = logging.get_logger(__name__)
|
||||
|
||||
_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
|
||||
|
||||
TORCH_INIT_FUNCTIONS = {
|
||||
"uniform_": nn.init.uniform_,
|
||||
"normal_": nn.init.normal_,
|
||||
"trunc_normal_": nn.init.trunc_normal_,
|
||||
"constant_": nn.init.constant_,
|
||||
"xavier_uniform_": nn.init.xavier_uniform_,
|
||||
"xavier_normal_": nn.init.xavier_normal_,
|
||||
"kaiming_uniform_": nn.init.kaiming_uniform_,
|
||||
"kaiming_normal_": nn.init.kaiming_normal_,
|
||||
"uniform": nn.init.uniform,
|
||||
"normal": nn.init.normal,
|
||||
"xavier_uniform": nn.init.xavier_uniform,
|
||||
"xavier_normal": nn.init.xavier_normal,
|
||||
"kaiming_uniform": nn.init.kaiming_uniform,
|
||||
"kaiming_normal": nn.init.kaiming_normal,
|
||||
}
|
||||
|
||||
if is_torch_version(">=", "1.9.0"):
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
||||
@@ -84,6 +120,8 @@ else:
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import load_offloaded_weights, save_offload_index
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
||||
@@ -159,6 +197,54 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
||||
return last_tuple[1].dtype
|
||||
|
||||
|
||||
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
|
||||
"""
|
||||
Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
|
||||
checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
|
||||
parameters.
|
||||
|
||||
"""
|
||||
if model_to_load.device.type == "meta":
|
||||
return False
|
||||
|
||||
if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
|
||||
return False
|
||||
|
||||
# Some models explicitly do not support param buffer assignment
|
||||
if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
|
||||
logger.debug(
|
||||
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
|
||||
)
|
||||
return False
|
||||
|
||||
# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
|
||||
first_key = next(iter(model_to_load.state_dict().keys()))
|
||||
if start_prefix + first_key in state_dict:
|
||||
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
|
||||
|
||||
return False
|
||||
|
||||
|
||||
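`check_support_param_buffer_assignment` gates the fast assign path: it bails out for meta-device models, empty state dicts, models that explicitly opt out, and dtype mismatches between checkpoint and model. A compact standalone sketch of the final dtype check it performs:

```python
# Standalone sketch of the final dtype check above; a dtype mismatch disables the
# fast assign path and falls back to the copying load.
import torch

model = torch.nn.Linear(4, 4)                      # float32 parameters
ckpt_fp16 = {k: v.half() for k, v in model.state_dict().items()}

first_key = next(iter(model.state_dict()))
supports_assign = ckpt_fp16[first_key].dtype == model.state_dict()[first_key].dtype
print(supports_assign)  # False: an fp16 checkpoint loaded into an fp32 model must be copied/cast
```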
@contextmanager
|
||||
def no_init_weights():
|
||||
"""
|
||||
Context manager to globally disable weight initialization to speed up loading large models. To do that, all the
|
||||
torch.nn.init function are all replaced with skip.
|
||||
"""
|
||||
|
||||
def _skip_init(*args, **kwargs):
|
||||
pass
|
||||
|
||||
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
||||
setattr(torch.nn.init, name, _skip_init)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Restore the original initialization functions
|
||||
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
||||
setattr(torch.nn.init, name, init_func)
|
||||
|
||||
|
||||
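`no_init_weights` monkey-patches every initializer listed in `TORCH_INIT_FUNCTIONS` to a no-op for the duration of model construction and then restores them, so `from_config` does not spend time on random init that the checkpoint is about to overwrite. A small self-contained sketch of the same trick, patching just one init function:

```python
# Self-contained sketch of the init-skipping trick above, patching a single initializer.
from contextlib import contextmanager

import torch


@contextmanager
def skip_kaiming_uniform():
    original = torch.nn.init.kaiming_uniform_
    torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None  # skip the init work
    try:
        yield
    finally:
        torch.nn.init.kaiming_uniform_ = original  # always restore the real initializer


with skip_kaiming_uniform():
    layer = torch.nn.Linear(1024, 1024)  # constructed without the usual kaiming init pass
```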
class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
r"""
|
||||
Base class for all models.
|
||||
@@ -785,7 +871,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
offload_folder = kwargs.pop("offload_folder", None)
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
@@ -862,14 +948,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
|
||||
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
unused_kwargs = {}
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
|
||||
# load config
|
||||
config, unused_kwargs, commit_hash = cls.load_config(
|
||||
@@ -907,13 +994,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
hf_quantizer = None
|
||||
|
||||
if hf_quantizer is not None:
|
||||
if device_map is not None:
|
||||
raise NotImplementedError(
|
||||
"Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
|
||||
)
|
||||
|
||||
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
device_map = hf_quantizer.update_device_map(device_map)
|
||||
|
||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
|
||||
@@ -926,9 +1009,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and (
|
||||
hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
|
||||
)
|
||||
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
||||
if not isinstance(keep_in_fp32_modules, list):
|
||||
@@ -941,10 +1025,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
#######################################
|
||||
|
||||
is_sharded = False
|
||||
resolved_model_file = None
|
||||
|
||||
# Determine if we're loading from a directory of sharded checkpoints.
|
||||
is_sharded = False
|
||||
sharded_metadata = None
|
||||
index_file = None
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
index_file_kwargs = {
|
||||
@@ -975,9 +1061,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
|
||||
|
||||
# load model
|
||||
model_file = None
|
||||
if from_flax:
|
||||
model_file = _get_model_file(
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=FLAX_WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
@@ -995,11 +1080,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# Convert the weights
|
||||
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
||||
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
|
||||
else:
|
||||
# in the case it is sharded, we have already the index
|
||||
if is_sharded:
|
||||
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
|
||||
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
|
||||
pretrained_model_name_or_path,
|
||||
index_file,
|
||||
cache_dir=cache_dir,
|
||||
@@ -1011,17 +1096,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
subfolder=subfolder or "",
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
# TODO: https://github.com/huggingface/diffusers/issues/10013
|
||||
if hf_quantizer is not None or dduf_entries:
|
||||
model_file = _merge_sharded_checkpoints(
|
||||
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries
|
||||
)
|
||||
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
|
||||
is_sharded = False
|
||||
|
||||
elif use_safetensors and not is_sharded:
|
||||
elif use_safetensors:
|
||||
try:
|
||||
model_file = _get_model_file(
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
@@ -1044,8 +1121,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
||||
)
|
||||
|
||||
if model_file is None and not is_sharded:
|
||||
model_file = _get_model_file(
|
||||
if resolved_model_file is None and not is_sharded:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
@@ -1060,157 +1137,104 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
with accelerate.init_empty_weights():
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
if not isinstance(resolved_model_file, list):
|
||||
resolved_model_file = [resolved_model_file]
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
|
||||
)
|
||||
|
||||
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
||||
if device_map is None and not is_sharded:
|
||||
# `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
|
||||
# It would error out during the `validate_environment()` call above in the absence of cuda.
|
||||
if hf_quantizer is None:
|
||||
param_device = "cpu"
|
||||
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
|
||||
else:
|
||||
param_device = torch.device(torch.cuda.current_device())
|
||||
state_dict = load_state_dict(
|
||||
model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
|
||||
)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
# move the params from meta device to cpu
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
if hf_quantizer is not None:
|
||||
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(
|
||||
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
||||
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
||||
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
||||
" those weights or else make sure your checkpoint file is correct."
|
||||
)
|
||||
|
||||
named_buffers = model.named_buffers()
|
||||
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
state_dict,
|
||||
device=param_device,
|
||||
dtype=torch_dtype,
|
||||
model_name_or_path=pretrained_model_name_or_path,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
named_buffers=named_buffers,
|
||||
)
|
||||
|
||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
else: # else let accelerate handle loading and dispatching.
|
||||
# Load weights and dispatch according to the device_map
|
||||
# by default the device_map is None and the weights are loaded on the CPU
|
||||
device_map = _determine_device_map(
|
||||
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
|
||||
)
|
||||
if device_map is None and is_sharded:
|
||||
# we load the parameters on the cpu
|
||||
device_map = {"": "cpu"}
|
||||
try:
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_file if not is_sharded else index_file,
|
||||
device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
strict=True,
|
||||
)
|
||||
except AttributeError as e:
|
||||
# When using accelerate loading, we do not have the ability to load the state
|
||||
# dict and rename the weight names manually. Additionally, accelerate skips
|
||||
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
||||
# (which look like they should be private variables?), so we can't use the standard hooks
|
||||
# to rename parameters on load. We need to mimic the original weight names so the correct
|
||||
# attributes are available. After we have loaded the weights, we convert the deprecated
|
||||
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
||||
# the weights so we don't have to do this again.
|
||||
|
||||
if "'Attention' object has no attribute" in str(e):
|
||||
logger.warning(
|
||||
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
||||
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
||||
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
||||
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
||||
" please also re-upload it or open a PR on the original repository."
|
||||
)
|
||||
model._temp_convert_self_to_deprecated_attention_blocks()
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_file if not is_sharded else index_file,
|
||||
device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
strict=True,
|
||||
)
|
||||
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
||||
else:
|
||||
raise e
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": [],
|
||||
"unexpected_keys": [],
|
||||
"mismatched_keys": [],
|
||||
"error_msgs": [],
|
||||
}
|
||||
else:
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(
|
||||
model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
|
||||
# set dtype to instantiate the model under:
|
||||
# 1. If torch_dtype is not None, we use that dtype
|
||||
# 2. If torch_dtype is float8, we don't use _set_default_torch_dtype and we downcast after loading the model
|
||||
dtype_orig = None
|
||||
if torch_dtype is not None and not torch_dtype == getattr(torch, "float8_e4m3fn", None):
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
raise ValueError(
|
||||
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
||||
)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
model_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
init_contexts = [no_init_weights()]
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
"unexpected_keys": unexpected_keys,
|
||||
"mismatched_keys": mismatched_keys,
|
||||
"error_msgs": error_msgs,
|
||||
}
|
||||
if low_cpu_mem_usage:
|
||||
init_contexts.append(accelerate.init_empty_weights())
|
||||
|
||||
with ContextManagers(init_contexts):
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
if dtype_orig is not None:
|
||||
torch.set_default_dtype(dtype_orig)
|
||||
|
||||
state_dict = None
|
||||
if not is_sharded:
|
||||
# Time to load the checkpoint
|
||||
state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
|
||||
# We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
|
||||
model._fix_state_dict_keys_on_load(state_dict)
|
||||
|
||||
if is_sharded:
|
||||
loaded_keys = sharded_metadata["all_checkpoint_keys"]
|
||||
else:
|
||||
loaded_keys = list(state_dict.keys())
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
|
||||
)
|
||||
print(keep_in_fp32_modules)
|
||||
# Now that the model is loaded, we can determine the device_map
|
||||
device_map = _determine_device_map(
|
||||
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
|
||||
)
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.validate_environment(device_map=device_map)
|
||||
|
||||
(
|
||||
model,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
mismatched_keys,
|
||||
offload_index,
|
||||
error_msgs,
|
||||
) = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
resolved_model_file,
|
||||
pretrained_model_name_or_path,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
device_map=device_map,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
"unexpected_keys": unexpected_keys,
|
||||
"mismatched_keys": mismatched_keys,
|
||||
"error_msgs": error_msgs,
|
||||
}
|
||||
|
||||
# Dispatch model with hooks on all devices if necessary
|
||||
if device_map is not None:
|
||||
device_map_kwargs = {
|
||||
"device_map": device_map,
|
||||
"offload_dir": offload_folder,
|
||||
"offload_index": offload_index,
|
||||
}
|
||||
dispatch_model(model, **device_map_kwargs)
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.postprocess_model(model)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
raise ValueError(
|
||||
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
||||
)
|
||||
# When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
|
||||
# completely lose the effectivity of `use_keep_in_fp32_modules`.
|
||||
elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
|
||||
if (
|
||||
torch_dtype is not None
|
||||
and torch_dtype == getattr(torch, "float8_e4m3fn", None)
|
||||
and hf_quantizer is None
|
||||
and not use_keep_in_fp32_modules
|
||||
):
|
||||
model = model.to(torch_dtype)
|
||||
|
||||
if hf_quantizer is not None:
|
||||
@@ -1222,6 +1246,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
if output_loading_info:
|
||||
return model, loading_info
|
||||
|
||||
@@ -1332,54 +1357,127 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
cls,
|
||||
model,
|
||||
state_dict: OrderedDict,
|
||||
resolved_archive_file,
|
||||
resolved_model_file: List[str],
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
loaded_keys: List[str],
|
||||
ignore_mismatched_sizes: bool = False,
|
||||
assign_to_params_buffers: bool = False,
|
||||
hf_quantizer: Optional[DiffusersQuantizer] = None,
|
||||
low_cpu_mem_usage: bool = True,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||
device_map: Dict[str, Union[int, str, torch.device]] = None,
|
||||
offload_state_dict: Optional[bool] = None,
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
):
|
||||
# Retrieve missing & unexpected_keys
|
||||
model_state_dict = model.state_dict()
|
||||
loaded_keys = list(state_dict.keys())
|
||||
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
|
||||
original_loaded_keys = loaded_keys
|
||||
|
||||
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||
if hf_quantizer is not None:
|
||||
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
|
||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
||||
# the user.
|
||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
# Make sure we are able to load base models as well as derived models (with heads)
|
||||
model_to_load = model
|
||||
mismatched_keys = []
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key in loaded_keys:
|
||||
model_key = checkpoint_key
|
||||
assign_to_params_buffers = None
|
||||
error_msgs = []
|
||||
|
||||
if (
|
||||
model_key in model_state_dict
|
||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||
):
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
# Deal with offload
|
||||
if device_map is not None and "disk" in device_map.values():
|
||||
if offload_folder is None:
|
||||
raise ValueError(
|
||||
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
|
||||
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
|
||||
" offers the weights in this format."
|
||||
)
|
||||
if offload_folder is not None:
|
||||
os.makedirs(offload_folder, exist_ok=True)
|
||||
if offload_state_dict is None:
|
||||
offload_state_dict = True
|
||||
|
||||
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
|
||||
if offload_state_dict:
|
||||
state_dict_folder = tempfile.mkdtemp()
|
||||
state_dict_index = {}
|
||||
else:
|
||||
state_dict_folder = None
|
||||
state_dict_index = None
|
||||
|
||||
if state_dict is not None:
|
||||
# Whole checkpoint
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
# load_state_dict will manage the case where we pass a dict instead of a file
|
||||
# if state dict is not None, it means that we don't need to read the files from resolved_model_file also
|
||||
resolved_model_file = [state_dict]
|
||||
|
||||
if len(resolved_model_file) > 1:
|
||||
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
|
||||
|
||||
for shard_file in resolved_model_file:
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
original_loaded_keys,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key in loaded_keys:
|
||||
model_key = checkpoint_key
|
||||
# If the checkpoint is sharded, we may not have the key here.
|
||||
if checkpoint_key not in state_dict:
|
||||
continue
|
||||
|
||||
if (
|
||||
model_key in model_state_dict
|
||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||
):
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
|
||||
mismatched_keys += _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
)
|
||||
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
offload_index, state_dict_index = load_model_dict_into_meta(
|
||||
model,
|
||||
state_dict,
|
||||
device_map=device_map,
|
||||
dtype=dtype,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
offload_folder=offload_folder,
|
||||
offload_index=offload_index,
|
||||
state_dict_index=state_dict_index,
|
||||
state_dict_folder=state_dict_folder,
|
||||
)
|
||||
else:
|
||||
if assign_to_params_buffers is None:
|
||||
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
|
||||
|
||||
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
|
||||
|
||||
if offload_index is not None and len(offload_index) > 0:
|
||||
save_offload_index(offload_index, offload_folder)
|
||||
offload_index = None
|
||||
|
||||
if offload_state_dict:
|
||||
load_offloaded_weights(model, state_dict_index, state_dict_folder)
|
||||
shutil.rmtree(state_dict_folder)
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
error_msg = "\n\t".join(error_msgs)
|
||||
@@ -1391,17 +1489,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
||||
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
||||
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
||||
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
||||
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
||||
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
||||
" identical (initializing a BertForSequenceClassification model from a"
|
||||
" BertForSequenceClassification model)."
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
@@ -1429,7 +1521,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
" able to use it for predictions and inference."
|
||||
)
|
||||
|
||||
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
||||
return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
@@ -1470,6 +1562,33 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
modules_to_check += list(module.children())
|
||||
return list(_no_split_modules)
|
||||
|
||||
@classmethod
|
||||
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
|
||||
"""
|
||||
Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
|
||||
under specific dtype.
|
||||
|
||||
Args:
|
||||
dtype (`torch.dtype`):
|
||||
a floating dtype to set to.
|
||||
|
||||
Returns:
|
||||
`torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
|
||||
modified. If it wasn't, returns `None`.
|
||||
|
||||
Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
|
||||
`torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
|
||||
"""
|
||||
if not dtype.is_floating_point:
|
||||
raise ValueError(
|
||||
f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
|
||||
)
|
||||
|
||||
logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
|
||||
dtype_orig = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
return dtype_orig
|
||||
|
||||
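`_set_default_torch_dtype` swaps the global default dtype so freshly constructed parameters already come out in the target precision, and the caller restores the previous default afterwards (as `from_pretrained` does above with `dtype_orig`). A standalone sketch of that capture-and-restore pattern using only public torch APIs; float64 is used here purely for illustration, while the loader applies it to dtypes such as fp16/bf16:

```python
# Standalone sketch of the default-dtype capture/restore pattern used above.
import torch


def build_in_dtype(dtype):
    if not dtype.is_floating_point:
        raise ValueError(f"Can't instantiate a model under dtype={dtype}: not a floating point dtype")
    dtype_orig = torch.get_default_dtype()
    torch.set_default_dtype(dtype)  # new parameters are now created directly in `dtype`
    try:
        model = torch.nn.Linear(4, 4)
    finally:
        torch.set_default_dtype(dtype_orig)  # restore, mirroring torch.set_default_dtype(dtype_orig)
    return model


print(build_in_dtype(torch.float64).weight.dtype)  # torch.float64
```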
@property
|
||||
def device(self) -> torch.device:
|
||||
"""
|
||||
@@ -1585,7 +1704,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
|
||||
)
|
||||
|
||||
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
|
||||
def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
|
||||
"""
|
||||
This function fix the state dict of the model to take into account some changes that were made in the model
|
||||
architecture:
|
||||
- deprecated attention blocks (happened before we introduced sharded checkpoint,
|
||||
so this is why we apply this method only when loading non sharded checkpoints for now)
|
||||
"""
|
||||
deprecated_attention_block_paths = []
|
||||
|
||||
def recursive_find_attn_block(name, module):
|
||||
@@ -1628,56 +1753,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
||||
if f"{path}.proj_attn.bias" in state_dict:
|
||||
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
||||
|
||||
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
||||
deprecated_attention_block_modules = []
|
||||
|
||||
def recursive_find_attn_block(module):
|
||||
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
||||
deprecated_attention_block_modules.append(module)
|
||||
|
||||
for sub_module in module.children():
|
||||
recursive_find_attn_block(sub_module)
|
||||
|
||||
recursive_find_attn_block(self)
|
||||
|
||||
for module in deprecated_attention_block_modules:
|
||||
module.query = module.to_q
|
||||
module.key = module.to_k
|
||||
module.value = module.to_v
|
||||
module.proj_attn = module.to_out[0]
|
||||
|
||||
# We don't _have_ to delete the old attributes, but it's helpful to ensure
|
||||
# that _all_ the weights are loaded into the new attributes and we're not
|
||||
# making an incorrect assumption that this model should be converted when
|
||||
# it really shouldn't be.
|
||||
del module.to_q
|
||||
del module.to_k
|
||||
del module.to_v
|
||||
del module.to_out
|
||||
|
||||
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
||||
deprecated_attention_block_modules = []
|
||||
|
||||
def recursive_find_attn_block(module) -> None:
|
||||
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
||||
deprecated_attention_block_modules.append(module)
|
||||
|
||||
for sub_module in module.children():
|
||||
recursive_find_attn_block(sub_module)
|
||||
|
||||
recursive_find_attn_block(self)
|
||||
|
||||
for module in deprecated_attention_block_modules:
|
||||
module.to_q = module.query
|
||||
module.to_k = module.key
|
||||
module.to_v = module.value
|
||||
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
||||
|
||||
del module.query
|
||||
del module.key
|
||||
del module.value
|
||||
del module.proj_attn
|
||||
return state_dict
|
||||
|
||||
|
||||
class LegacyModelMixin(ModelMixin):
|
||||
|
||||
@@ -280,9 +280,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
act_fn="silu_fp32",
|
||||
)
|
||||
|
||||
self.text_embedding_padding = nn.Parameter(
|
||||
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
|
||||
)
|
||||
self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
|
||||
@@ -693,7 +693,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
offload_folder = kwargs.pop("offload_folder", None)
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
dduf_file = kwargs.pop("dduf_file", None)
|
||||
|
||||
@@ -235,18 +235,16 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
torch_dtype = torch.float16
|
||||
return torch_dtype
|
||||
|
||||
# (sayakpaul): I think it could be better to disable custom `device_map`s
|
||||
# for the first phase of the integration in the interest of simplicity.
|
||||
# Commenting this for discussions on the PR.
|
||||
# def update_device_map(self, device_map):
|
||||
# if device_map is None:
|
||||
# device_map = {"": torch.cuda.current_device()}
|
||||
# logger.info(
|
||||
# "The device_map was not initialized. "
|
||||
# "Setting device_map to {'':torch.cuda.current_device()}. "
|
||||
# "If you want to use the model for inference, please set device_map ='auto' "
|
||||
# )
|
||||
# return device_map
|
||||
def update_device_map(self, device_map):
|
||||
if device_map is None:
|
||||
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
|
||||
logger.info(
|
||||
"The device_map was not initialized. "
|
||||
"Setting device_map to {"
|
||||
": f`cuda:{torch.cuda.current_device()}`}. "
|
||||
"If you want to use the model for inference, please set device_map ='auto' "
|
||||
)
|
||||
return device_map
|
||||
|
||||
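Re-enabling `update_device_map` means a bitsandbytes-quantized model loaded without an explicit `device_map` now receives a whole-model map pointing at the current CUDA device instead of being rejected. A hedged sketch of the resulting default, assuming a single CUDA device with index 0 and omitting the log message:

```python
# Hedged sketch of the default map produced above when `device_map` is None
# (assumes a CUDA device with index 0 is available; logging omitted).
import torch


def update_device_map(device_map):
    if device_map is None:
        device_map = {"": f"cuda:{torch.cuda.current_device()}"}
    return device_map


# With no user-provided map, the whole model ("" key) lands on the current GPU:
#   update_device_map(None)      -> {"": "cuda:0"}
# An explicit map is passed through untouched:
#   update_device_map({"": "cpu"}) -> {"": "cpu"}
```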
def _process_model_before_weight_loading(
self,
@@ -289,9 +287,9 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
model.is_loaded_in_4bit = True

def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
model.is_loaded_in_4bit = True
model.is_4bit_serializable = self.is_serializable
return model

@@ -400,16 +398,17 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
torch_dtype = torch.float16
return torch_dtype

# # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
# def update_device_map(self, device_map):
# if device_map is None:
# device_map = {"": torch.cuda.current_device()}
# logger.info(
# "The device_map was not initialized. "
# "Setting device_map to {'':torch.cuda.current_device()}. "
# "If you want to use the model for inference, please set device_map ='auto' "
# )
# return device_map
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
def update_device_map(self, device_map):
if device_map is None:
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
logger.info(
"The device_map was not initialized. "
"Setting device_map to {'': f'cuda:{torch.cuda.current_device()}'}. "
"If you want to use the model for inference, please set device_map ='auto' "
)
return device_map

def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if target_dtype != torch.int8:
@@ -493,11 +492,10 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):

# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
model.is_loaded_in_8bit = True
model.is_8bit_serializable = self.is_serializable
return model

# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
@@ -539,6 +537,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
model.is_loaded_in_8bit = True

@property
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable

@@ -338,22 +338,6 @@ def _get_model_file(
) from e


# Adapted from
# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976
# Differences are in parallelization of shard downloads and checking if shards are present.


def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
shards_path = os.path.join(local_dir, subfolder)
shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
for shard_file in shard_filenames:
if not os.path.exists(shard_file):
raise ValueError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)


def _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_filename,
@@ -396,13 +380,22 @@ def _get_checkpoint_shard_files(
shards_path = os.path.join(pretrained_model_name_or_path, subfolder)

# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
_check_if_shards_exist_locally(
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
)
return shards_path, sharded_metadata
elif dduf_entries:
return shards_path, sharded_metadata
if os.path.isdir(pretrained_model_name_or_path) or dduf_entries:
shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
for shard_file in shard_filenames:
if dduf_entries:
if shard_file not in dduf_entries:
raise FileNotFoundError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
else:
if not os.path.exists(shard_file):
raise FileNotFoundError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
return shard_filenames, sharded_metadata

# At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns = original_shard_filenames
@@ -444,7 +437,9 @@ def _get_checkpoint_shard_files(
" again after checking your internet connection."
) from e

return cached_folder, sharded_metadata
cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]

return cached_filenames, sharded_metadata

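# A minimal sketch of the data the shard check above operates on (assumption: the standard
# Hugging Face sharded-checkpoint layout; `index_path` and `folder` below are hypothetical names
# used only for illustration). The index JSON maps parameter names to shard files, and the unique
# values of "weight_map" are the `original_shard_filenames` whose presence is verified on disk or
# inside the DDUF archive before loading proceeds.
#
#   import json
#   import os
#
#   with open(index_path) as f:
#       index = json.load(f)
#   original_shard_filenames = sorted(set(index["weight_map"].values()))
#   missing = [
#       name for name in original_shard_filenames if not os.path.exists(os.path.join(folder, name))
#   ]
#   if missing:
#       raise FileNotFoundError(f"Missing shards: {missing}")
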
def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):

@@ -37,7 +37,7 @@ from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
from requests.exceptions import HTTPError

from diffusers.models import UNet2DConditionModel
from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
@@ -200,12 +200,12 @@ class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
super().tearDown()

def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context:
def test_missing_key_loading_warning_message(self):
with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")

# make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception)
assert "conv_out.bias" in " ".join(logs.output)

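# A minimal sketch of what the renamed test above relies on (assumption: with this refactor,
# missing keys are reported through the `diffusers.models.modeling_utils` logger instead of
# raising). Outside of unittest, the same warning can be captured with the standard library:
#
#   import logging
#
#   records = []
#   handler = logging.Handler()
#   handler.emit = records.append  # minimal capture hook, for illustration only
#   logging.getLogger("diffusers.models.modeling_utils").addHandler(handler)
#
#   UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
#   assert any("conv_out.bias" in record.getMessage() for record in records)
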
@parameterized.expand(
[
@@ -334,6 +334,58 @@ class ModelUtilsTest(unittest.TestCase):

assert model.config.in_channels == 9

@require_torch_gpu
def test_keep_modules_in_fp32(self):
r"""
A simple test to check that the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16.
Also ensures that inference works.
"""
fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules

for torch_dtype in [torch.bfloat16, torch.float16]:
SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]

model = SD3Transformer2DModel.from_pretrained(
"hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch_dtype
).to(torch_device)

for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if name in model._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.float32)
else:
self.assertTrue(module.weight.dtype == torch_dtype)

def get_dummy_inputs():
batch_size = 2
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154

hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}

# test if inference works.
with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch_dtype):
input_dict_for_transformer = get_dummy_inputs()
model_inputs = {
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
}
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
_ = model(**model_inputs)

SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules

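# A minimal sketch of the behaviour the test above pins down (illustrative only; "proj_out" is in
# `_keep_in_fp32_modules` here only because the test patches it in, and the second module name is
# just an example of an ordinary linear layer):
#
#   SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
#   model = SD3Transformer2DModel.from_pretrained(
#       "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch.float16
#   )
#   assert model.proj_out.weight.dtype == torch.float32          # listed module stays in fp32
#   assert model.context_embedder.weight.dtype == torch.float16  # other linear layers follow torch_dtype
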
class UNetTesterMixin:
def test_forward_with_norm_groups(self):

@@ -136,7 +136,7 @@ class BnB4BitBasicTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
self.model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)

def tearDown(self):
@@ -202,7 +202,7 @@ class BnB4BitBasicTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
model = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)

for name, module in model.named_modules():
@@ -327,7 +327,7 @@ class BnB4BitBasicTests(Base4bitTests):
with tempfile.TemporaryDirectory() as tmpdirname:
nf4_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
model_4bit.save_pretrained(tmpdirname)
del model_4bit
@@ -362,7 +362,7 @@ class BnB4BitTrainingTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
self.model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)

def test_training(self):
@@ -410,7 +410,7 @@ class SlowBnb4BitTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
self.pipeline_4bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_4bit, torch_dtype=torch.float16
@@ -472,7 +472,7 @@ class SlowBnb4BitTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)

logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
@@ -502,6 +502,7 @@ class SlowBnb4BitTests(Base4bitTests):
subfolder="transformer",
quantization_config=transformer_nf4_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
text_encoder_3_nf4_config = BnbConfig(
load_in_4bit=True,
@@ -513,6 +514,7 @@ class SlowBnb4BitTests(Base4bitTests):
subfolder="text_encoder_3",
quantization_config=text_encoder_3_nf4_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
# CUDA device placement works.
pipeline_4bit = DiffusionPipeline.from_pretrained(
@@ -527,6 +529,94 @@ class SlowBnb4BitTests(Base4bitTests):

del pipeline_4bit

def test_device_map(self):
"""
Test if the quantized model works properly with device_map="auto".
Note that cpu/disk offloading doesn't work with bnb.
"""

def get_dummy_tensor_inputs(device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32

torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)

torch.manual_seed(seed)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

torch.manual_seed(seed)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

torch.manual_seed(seed)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}

inputs = get_dummy_tensor_inputs(torch_device)
expected_slice = np.array(
[0.47070312, 0.00390625, -0.03662109, -0.19628906, -0.53125, 0.5234375, -0.17089844, -0.59375, 0.578125]
)

# non sharded
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)

weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))

output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

# sharded

quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)

weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))

output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

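# A minimal sketch of what device_map="auto" is expected to produce for the bnb path above
# (illustrative assertions, not part of the test): all parameters end up on the accelerator and
# the quantized linear weights are bitsandbytes Params4bit.
#
#   param_devices = {param.device.type for param in quantized_model.parameters()}
#   assert param_devices == {"cuda"}  # no cpu/disk offload is expected for bnb
#   assert isinstance(
#       quantized_model.transformer_blocks[0].ff.net[2].weight, bnb.nn.modules.Params4bit
#   )
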
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
@@ -610,7 +700,10 @@ class BaseBnb4BitSerializationTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.bfloat16,
)
model_0 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=self.quantization_config
self.model_name,
subfolder="transformer",
quantization_config=self.quantization_config,
device_map=torch_device,
)
self.assertTrue("_pre_quantization_dtype" in model_0.config)
with tempfile.TemporaryDirectory() as tmpdirname:

@@ -138,7 +138,7 @@ class BnB8bitBasicTests(Base8bitTests):
)
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)

def tearDown(self):
@@ -200,7 +200,7 @@ class BnB8bitBasicTests(Base8bitTests):

mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)

for name, module in model.named_modules():
@@ -242,7 +242,7 @@ class BnB8bitBasicTests(Base8bitTests):
"""
config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"])
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=config
self.model_name, subfolder="transformer", quantization_config=config, device_map=torch_device
)
linear = get_some_linear_layer(model_8bit)
self.assertTrue(linear.weight.dtype == torch.int8)
@@ -319,6 +319,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
subfolder="transformer",
quantization_config=mixed_int8_config,
device_map=torch_device,
)

def tearDown(self):
@@ -343,7 +344,7 @@ class BnB8bitTrainingTests(Base8bitTests):

mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)

def test_training(self):
@@ -387,7 +388,7 @@ class SlowBnb8bitTests(Base8bitTests):

mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
self.pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_8bit, torch_dtype=torch.float16
@@ -415,7 +416,10 @@ class SlowBnb8bitTests(Base8bitTests):

def test_model_cpu_offload_raises_warning(self):
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
self.model_name,
subfolder="transformer",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map=torch_device,
)
pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_8bit, torch_dtype=torch.float16
@@ -430,7 +434,10 @@ class SlowBnb8bitTests(Base8bitTests):

def test_moving_to_cpu_throws_warning(self):
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
self.model_name,
subfolder="transformer",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map=torch_device,
)
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(30)
@@ -483,6 +490,7 @@ class SlowBnb8bitTests(Base8bitTests):
subfolder="transformer",
quantization_config=transformer_8bit_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
text_encoder_3_8bit = T5EncoderModel.from_pretrained(
@@ -490,6 +498,7 @@ class SlowBnb8bitTests(Base8bitTests):
subfolder="text_encoder_3",
quantization_config=text_encoder_3_8bit_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
# CUDA device placement works.
pipeline_8bit = DiffusionPipeline.from_pretrained(
@@ -504,6 +513,99 @@ class SlowBnb8bitTests(Base8bitTests):

del pipeline_8bit

def test_device_map(self):
"""
Test if the quantized model works properly with device_map="auto".
Note that cpu/disk offloading doesn't work with bnb.
"""

def get_dummy_tensor_inputs(device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32

torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
device, dtype=torch.bfloat16
)

torch.manual_seed(seed)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)

torch.manual_seed(seed)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

torch.manual_seed(seed)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

torch.manual_seed(seed)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}

inputs = get_dummy_tensor_inputs(torch_device)
expected_slice = np.array(
[
0.33789062,
-0.04736328,
-0.00256348,
-0.23144531,
-0.49804688,
0.4375,
-0.15429688,
-0.65234375,
0.44335938,
]
)

# non sharded
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)

weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))

output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

# sharded
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)

weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

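# A minimal sketch of the memory saving the 8-bit path above is meant to buy (illustrative only,
# assuming ModelMixin exposes a get_memory_footprint helper as its transformers counterpart does):
#
#   fp16_model = FluxTransformer2DModel.from_pretrained(
#       "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", torch_dtype=torch.float16
#   )
#   assert quantized_model.get_memory_footprint() < fp16_model.get_memory_footprint()
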
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
@@ -579,7 +681,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
load_in_8bit=True,
)
self.model_0 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=quantization_config
self.model_name, subfolder="transformer", quantization_config=quantization_config, device_map=torch_device
)

def tearDown(self):

@@ -34,6 +34,7 @@ from diffusers.utils.testing_utils import (
is_torch_available,
is_torchao_available,
nightly,
numpy_cosine_similarity_distance,
require_torch,
require_torch_gpu,
require_torchao_version_greater_or_equal,
@@ -282,9 +283,6 @@ class TorchAoTest(unittest.TestCase):
self.assertEqual(weight.quant_max, 15)

def test_device_map(self):
# Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did
# it would have errored out. Now, we do. So, device_map basically never worked with or without
# sharded checkpoints. This will need to be supported in the future (TODO(aryan))
"""
Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
@@ -301,54 +299,73 @@ class TorchAoTest(unittest.TestCase):
}
device_maps = ["auto", custom_device_map_dict]

# inputs = self.get_dummy_tensor_inputs(torch_device)
# expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])

inputs = self.get_dummy_tensor_inputs(torch_device)
# requires different expected slices since the models differ under offload (we don't quantize modules offloaded to cpu/disk)
expected_slice_auto = np.array(
[
0.34179688,
-0.03613281,
0.01428223,
-0.22949219,
-0.49609375,
0.4375,
-0.1640625,
-0.66015625,
0.43164062,
]
)
expected_slice_offload = np.array(
[0.34375, -0.03515625, 0.0123291, -0.22753906, -0.49414062, 0.4375, -0.16308594, -0.66015625, 0.43554688]
)
for device_map in device_maps:
# device_map_to_compare = {"": 0} if device_map == "auto" else device_map
if device_map == "auto":
expected_slice = expected_slice_auto
else:
expected_slice = expected_slice_offload
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)

# Test non-sharded model - should work
with self.assertRaises(NotImplementedError):
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight

# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
# Note that when performing cpu/disk offload, the offloaded weights are not quantized, only the weights on the gpu.
# This is not the case when the model is already quantized
if "transformer_blocks.0" in device_map:
self.assertTrue(isinstance(weight, nn.Parameter))
else:
self.assertTrue(isinstance(weight, AffineQuantizedTensor))

# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

# Test sharded model - should not work
with self.assertRaises(NotImplementedError):
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)

# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
if "transformer_blocks.0" in device_map:
self.assertTrue(isinstance(weight, nn.Parameter))
else:
self.assertTrue(isinstance(weight, AffineQuantizedTensor))

# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

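# A minimal sketch of the kind of custom device map the loop above iterates over (the values here
# are hypothetical; the real custom_device_map_dict is defined just above the hunk shown). Keys
# mapped to "cpu" or "disk" are offloaded and therefore stay plain nn.Parameter, which is what the
# `if "transformer_blocks.0" in device_map` branch checks; modules kept on the accelerator are
# quantized to AffineQuantizedTensor.
#
#   custom_device_map_dict = {
#       "time_text_embed": torch_device,
#       "context_embedder": torch_device,
#       "x_embedder": torch_device,
#       "transformer_blocks.0": "cpu",
#       "single_transformer_blocks.0": "disk",
#       "norm_out": torch_device,
#       "proj_out": "cpu",
#   }
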
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
@@ -544,7 +561,7 @@ class TorchAoSerializationTest(unittest.TestCase):
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
@@ -564,7 +581,7 @@ class TorchAoSerializationTest(unittest.TestCase):
loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
)
)
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

def test_int_a8w8_cuda(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}

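# A minimal sketch of the comparison the assertions above switch to (an assumption about the
# helper's internals; the diffusers implementation may differ in detail). Cosine-similarity
# distance compares the direction of the two slices rather than their element-wise values, so it
# tolerates the small per-element shifts that quantized kernels introduce while still catching
# real output changes.
#
#   import numpy as np
#
#   def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
#       a = np.asarray(a, dtype=np.float64).flatten()
#       b = np.asarray(b, dtype=np.float64).flatten()
#       similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
#       return float(1.0 - similarity)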