mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
AutoModel (#11115)
* AutoModel * ... * lol * ... * add test * update * make fix-copies --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -155,6 +155,7 @@ else:
|
||||
"AutoencoderKLWan",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"CacheMixin",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
@@ -197,6 +198,7 @@ else:
|
||||
"T2IAdapter",
|
||||
"T5FilmDecoder",
|
||||
"Transformer2DModel",
|
||||
"TransformerTemporalModel",
|
||||
"UNet1DModel",
|
||||
"UNet2DConditionModel",
|
||||
"UNet2DModel",
|
||||
@@ -731,6 +733,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
AutoModel,
|
||||
CacheMixin,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
@@ -772,6 +775,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
T2IAdapter,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
|
||||
@@ -41,6 +41,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["auto_model"] = ["AutoModel"]
|
||||
_import_structure["cache_utils"] = ["CacheMixin"]
|
||||
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
|
||||
@@ -103,6 +104,7 @@ if is_flax_available():
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .auto_model import AutoModel
|
||||
from .autoencoders import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderDC,
|
||||
|
||||
169
src/diffusers/models/auto_model.py
Normal file
169
src/diffusers/models/auto_model.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
|
||||
|
||||
class AutoModel(ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{self.__class__.__name__} is designed to be instantiated "
|
||||
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
||||
|
||||
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
|
||||
train the model, set it back in training mode with `model.train()`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`~ModelMixin.save_pretrained`].
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
from_flax (`bool`, *optional*, defaults to `False`):
|
||||
Load the model weights from a Flax checkpoint save file.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
|
||||
|
||||
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
||||
each GPU and the available CPU RAM if unset.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
The path to offload weights if `device_map` contains the value `"disk"`.
|
||||
offload_state_dict (`bool`, *optional*):
|
||||
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
||||
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
||||
when there is some disk offload.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
variant (`str`, *optional*):
|
||||
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
|
||||
loading `from_flax`.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
||||
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
||||
weights. If set to `False`, `safetensors` weights are not loaded.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
|
||||
<Tip>
|
||||
|
||||
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
||||
`huggingface-cli login`. You can also activate the special
|
||||
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
||||
firewalled environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
|
||||
unet = AutoModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
||||
```
|
||||
|
||||
If you get the error message below, you need to finetune the weights for your downstream task:
|
||||
|
||||
```bash
|
||||
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
||||
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
```
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"token": token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
}
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
orig_class_name = config["_class_name"]
|
||||
|
||||
library = importlib.import_module("diffusers")
|
||||
|
||||
model_cls = getattr(library, orig_class_name, None)
|
||||
if model_cls is None:
|
||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
@@ -280,6 +280,21 @@ class AutoencoderTiny(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CacheMixin(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -895,6 +910,21 @@ class Transformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class TransformerTemporalModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UNet1DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ from diffusers.models.attention_processor import (
|
||||
AttnProcessorNPU,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.auto_model import AutoModel
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -1577,6 +1578,49 @@ class ModelTesterMixin:
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
|
||||
|
||||
def test_auto_model(self, expected_max_diff=5e-5):
|
||||
if self.forward_requires_fresh_args:
|
||||
model = self.model_class(**self.init_dict)
|
||||
else:
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model = model.eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
if hasattr(model, "set_default_attn_processor"):
|
||||
model.set_default_attn_processor()
|
||||
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
|
||||
auto_model = AutoModel.from_pretrained(tmpdirname)
|
||||
if hasattr(auto_model, "set_default_attn_processor"):
|
||||
auto_model.set_default_attn_processor()
|
||||
|
||||
auto_model = auto_model.eval()
|
||||
auto_model = auto_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.forward_requires_fresh_args:
|
||||
output_original = model(**self.inputs_dict(0))
|
||||
output_auto = auto_model(**self.inputs_dict(0))
|
||||
else:
|
||||
output_original = model(**inputs_dict)
|
||||
output_auto = auto_model(**inputs_dict)
|
||||
|
||||
if isinstance(output_original, dict):
|
||||
output_original = output_original.to_tuple()[0]
|
||||
if isinstance(output_auto, dict):
|
||||
output_auto = output_auto.to_tuple()[0]
|
||||
|
||||
max_diff = (output_original - output_auto).abs().max().item()
|
||||
self.assertLessEqual(
|
||||
max_diff,
|
||||
expected_max_diff,
|
||||
f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}",
|
||||
)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ModelPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user