mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Support saving multiple t2i adapter models under one checkpoint (#4798)
* adding save and load for MultiAdapter, adding test * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Adding changes from review test_stable_diffusion_adapter * import sorting fix --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -11,17 +11,21 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
import os
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .modeling_utils import ModelMixin
|
||||
from .resnet import Downsample2D
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MultiAdapter(ModelMixin):
|
||||
r"""
|
||||
MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
|
||||
@@ -91,6 +95,120 @@ class MultiAdapter(ModelMixin):
|
||||
accume_state[i] += w * features[i]
|
||||
return accume_state
|
||||
|
||||
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    save_function: Optional[Callable] = None,
    safe_serialization: bool = True,
    variant: Optional[str] = None,
):
    """
    Save every wrapped adapter and its configuration file to disk, so that the whole set can be re-loaded using
    the [`~models.adapter.MultiAdapter.from_pretrained`] class method.

    The first adapter is written to `save_directory` itself. Every following adapter is written to a sibling
    directory whose name is `save_directory` with an index suffix appended (`<save_directory>_1`,
    `<save_directory>_2`, ...). This is the layout that `from_pretrained` scans for.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to which to save the first adapter. Will be created if it doesn't exist.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful when in distributed training like
            TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
            the main process to avoid race conditions.
        save_function (`Callable`):
            The function to use to save the state dictionary. Useful on distributed training like TPUs when one
            need to replace `torch.save` by another method. Can be configured with the environment variable
            `DIFFUSERS_SAVE_MODE`.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
        variant (`str`, *optional*):
            If specified, weights are saved in the format pytorch_model.<variant>.bin.
    """
    # Normalise to a plain string so the index suffix can also be appended to `os.PathLike` inputs.
    base_path = os.fspath(save_directory)

    for idx, adapter in enumerate(self.adapters):
        # Always derive the target from the base path. Appending to the previous target would produce
        # `<dir>_1_2` for the third adapter, which `from_pretrained` never looks for.
        model_path_to_save = base_path if idx == 0 else f"{base_path}_{idx}"
        adapter.save_pretrained(
            model_path_to_save,
            is_main_process=is_main_process,
            save_function=save_function,
            safe_serialization=safe_serialization,
            variant=variant,
        )
|
||||
|
||||
@classmethod
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
    r"""
    Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.

    Adapters are discovered on disk: the first one is loaded from `pretrained_model_path`, the following ones
    from `<pretrained_model_path>_1`, `<pretrained_model_path>_2`, ... until a directory is missing. All keyword
    arguments are forwarded unchanged to `T2IAdapter.from_pretrained` for every adapter.

    The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
    the model, you should first set it back in training mode with `model.train()`.

    The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
    task.

    The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_path (`str` or `os.PathLike`):
            A path to a *directory* containing model weights saved using
            [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g.,
            `./my_model_directory/adapter`.
        torch_dtype (`str` or `torch.dtype`, *optional*):
            Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
            will be automatically derived from the model's weights.
        output_loading_info (`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn't need to be refined to each
            parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
            same device.

            To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
            more information about each option see [designing a device
            map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
        max_memory (`Dict`, *optional*):
            A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
            GPU and the available CPU RAM if unset.
        low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
            Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
            also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
            model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
            setting this argument to `True` will raise an error.
        variant (`str`, *optional*):
            If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
            ignored when using `from_flax`.
        use_safetensors (`bool`, *optional*, defaults to `None`):
            If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
            `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
            `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.

    Raises:
        TypeError: If `pretrained_model_path` is not a `str` or `os.PathLike` (e.g. `None`).
        ValueError: If no adapter directory exists at `pretrained_model_path`.
    """
    # Normalise to a plain string so the index suffix can also be appended to `os.PathLike` inputs.
    base_path = os.fspath(pretrained_model_path)

    # load adapter and append to list until no adapter directory exists anymore
    # first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
    # second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
    adapters = []
    idx = 0
    model_path_to_load = base_path
    while os.path.isdir(model_path_to_load):
        adapters.append(T2IAdapter.from_pretrained(model_path_to_load, **kwargs))
        idx += 1
        model_path_to_load = f"{base_path}_{idx}"

    if not adapters:
        # The first adapter carries no index suffix, so the missing directory is the base path itself.
        raise ValueError(
            f"No T2IAdapters found under {os.path.dirname(base_path)}. Expected at least {base_path}."
        )

    logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")

    return cls(adapters)
|
||||
|
||||
|
||||
class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
|
||||
@@ -404,15 +404,6 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
|
||||
if test_mean_pixel_difference:
|
||||
assert_mean_pixel_difference(output_batch[0][0], output[0][0])
|
||||
|
||||
# We do not support saving pipelines with multiple adapters. The multiple adapters should be saved as their
|
||||
# own independent pipelines
|
||||
|
||||
# Placeholder test: the body is a bare `...`, so it performs no checks. The comment preceding these stubs
# states that saving pipelines with multiple adapters is not supported.
# NOTE(review): presumably this overrides an inherited save/load test to disable it - confirm against the
# base test classes.
def test_save_load_local(self):
|
||||
...
|
||||
|
||||
# Placeholder test: the body is a bare `...`, so it performs no checks. The comment preceding these stubs
# states that saving pipelines with multiple adapters is not supported.
# NOTE(review): presumably this overrides an inherited optional-components save/load test to disable it -
# confirm against the base test classes.
def test_save_load_optional_components(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user