1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Support saving multiple t2i adapter models under one checkpoint (#4798)

* adding save and load for MultiAdapter, adding test

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Adding changes from review test_stable_diffusion_adapter

* import sorting fix

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
VitjanZ
2023-08-29 09:24:40 +02:00
committed by GitHub
parent 3eeaf4e041
commit 7200daa412
2 changed files with 120 additions and 11 deletions

View File

@@ -11,17 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import os
from typing import Callable, List, Optional, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .modeling_utils import ModelMixin
from .resnet import Downsample2D
logger = logging.get_logger(__name__)
class MultiAdapter(ModelMixin):
r"""
MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
@@ -91,6 +95,120 @@ class MultiAdapter(ModelMixin):
accume_state[i] += w * features[i]
return accume_state
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = None,
safe_serialization: bool = True,
variant: Optional[str] = None,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
idx = 0
model_path_to_save = save_directory
for adapter in self.adapters:
adapter.save_pretrained(
model_path_to_save,
is_main_process=is_main_process,
save_function=save_function,
safe_serialization=safe_serialization,
variant=variant,
)
idx += 1
model_path_to_save = model_path_to_save + f"_{idx}"
@classmethod
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
the model, you should first set it back in training mode with `model.train()`.
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_path (`os.PathLike`):
A path to a *directory* containing model weights saved using
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g.,
`./my_model_directory/adapter`.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
max_memory (`Dict`, *optional*):
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
GPU and the available CPU RAM if unset.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
"""
idx = 0
adapters = []
# load adapter and append to list until no adapter directory exists anymore
# first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
# second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
model_path_to_load = pretrained_model_path
while os.path.isdir(model_path_to_load):
adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs)
adapters.append(adapter)
idx += 1
model_path_to_load = pretrained_model_path + f"_{idx}"
logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")
if len(adapters) == 0:
raise ValueError(
f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
)
return cls(adapters)
class T2IAdapter(ModelMixin, ConfigMixin):
r"""

View File

@@ -404,15 +404,6 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_batch[0][0], output[0][0])
# We do not support saving pipelines with multiple adapters. The multiple adapters should be saved as their
# own independent pipelines
def test_save_load_local(self):
...
def test_save_load_optional_components(self):
...
@slow
@require_torch_gpu