Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

[MultiControlNet] Allow save and load (#3747)

* [MultiControlNet] Allow save and load
* Correct more
* [MultiControlNet] Allow save and load
* make style
* Apply suggestions from code review

parent ef9590712a
commit 34d14d7848

@@ -1,10 +1,15 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from ...models.controlnet import ControlNetModel, ControlNetOutput
 from ...models.modeling_utils import ModelMixin
+from ...utils import logging
 
 
+logger = logging.get_logger(__name__)
+
+
 class MultiControlNetModel(ModelMixin):

@@ -64,3 +69,117 @@ class MultiControlNetModel(ModelMixin):
             mid_block_res_sample += mid_sample
 
         return down_block_res_samples, mid_block_res_sample
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        save_function: Callable = None,
+        safe_serialization: bool = False,
+        variant: Optional[str] = None,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~pipelines.controlnet.MultiControlNetModel.from_pretrained`] class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful when in distributed training like
+                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+                the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+                needs to replace `torch.save` by another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+        """
+        idx = 0
+        model_path_to_save = save_directory
+        for controlnet in self.nets:
+            controlnet.save_pretrained(
+                model_path_to_save,
+                is_main_process=is_main_process,
+                save_function=save_function,
+                safe_serialization=safe_serialization,
+                variant=variant,
+            )
+
+            idx += 1
+            model_path_to_save = model_path_to_save + f"_{idx}"
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you should first set it back in training mode with `model.train()`.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_path (`os.PathLike`):
+                A path to a *directory* containing model weights saved using
+                [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
+                `./my_model_directory/controlnet`.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed, the dtype
+                will be automatically derived from the model's weights.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available
+                for each GPU and the available CPU RAM if unset.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                If specified, load weights from the `variant` filename, *e.g.* `pytorch_model.<variant>.bin`. `variant`
+                is ignored when using `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+        """
+        idx = 0
+        controlnets = []
+
+        # load controlnet and append to list until no controlnet directory exists anymore
+        # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_pretrained`
+        # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+        model_path_to_load = pretrained_model_path
+        while os.path.isdir(model_path_to_load):
+            controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
+            controlnets.append(controlnet)
+
+            idx += 1
+            model_path_to_load = pretrained_model_path + f"_{idx}"
+
+        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+        if len(controlnets) == 0:
+            raise ValueError(
+                f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+            )
+
+        return cls(controlnets)

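For context, a minimal usage sketch of the new save/load round trip (not part of the diff; the hub checkpoint ids are placeholders, and the import path follows the docstring reference above):

    from diffusers import ControlNetModel
    from diffusers.pipelines.controlnet import MultiControlNetModel

    controlnets = [
        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny"),
        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose"),
    ]
    multi_controlnet = MultiControlNetModel(controlnets)

    # The first net is written to ./my_model_directory/controlnet and the second to
    # ./my_model_directory/controlnet_1, which is exactly the layout from_pretrained walks.
    multi_controlnet.save_pretrained("./my_model_directory/controlnet")
    multi_controlnet = MultiControlNetModel.from_pretrained("./my_model_directory/controlnet")
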
@@ -14,7 +14,6 @@
 
 
 import inspect
-import os
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -560,7 +559,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 raise ValueError("A single batch of multiple conditionings are supported at the moment.")
             elif len(image) != len(self.controlnet.nets):
                 raise ValueError(
-                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                 )
 
             for image_ in image:

@@ -679,18 +678,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    # override DiffusionPipeline
-    def save_pretrained(
-        self,
-        save_directory: Union[str, os.PathLike],
-        safe_serialization: bool = False,
-        variant: Optional[str] = None,
-    ):
-        if isinstance(self.controlnet, ControlNetModel):
-            super().save_pretrained(save_directory, safe_serialization, variant)
-        else:
-            raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

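With this override gone, and MultiControlNetModel now implementing save/load, saving and re-loading a multi-ControlNet pipeline should go through the generic DiffusionPipeline path. A hedged sketch of what that enables, not part of the diff; the model ids are placeholders:

    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    controlnets = [
        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny"),
        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose"),
    ]
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnets
    )

    # The list is wrapped in a MultiControlNetModel, so saving the pipeline should
    # now write controlnet/ and controlnet_1/ instead of raising NotImplementedError.
    pipe.save_pretrained("./sd-multi-controlnet")
    pipe = StableDiffusionControlNetPipeline.from_pretrained("./sd-multi-controlnet")

    # At call time, `image` must hold one conditioning image per ControlNet,
    # e.g. pipe(prompt, image=[canny_image, pose_image], ...).
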
@@ -14,7 +14,6 @@
 
 
 import inspect
-import os
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -586,7 +585,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
                 raise ValueError("A single batch of multiple conditionings are supported at the moment.")
             elif len(image) != len(self.controlnet.nets):
                 raise ValueError(
-                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                 )
 
             for image_ in image:

@@ -757,18 +756,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
 
         return latents
 
-    # override DiffusionPipeline
-    def save_pretrained(
-        self,
-        save_directory: Union[str, os.PathLike],
-        safe_serialization: bool = False,
-        variant: Optional[str] = None,
-    ):
-        if isinstance(self.controlnet, ControlNetModel):
-            super().save_pretrained(save_directory, safe_serialization, variant)
-        else:
-            raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -15,7 +15,6 @@
 # This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
 
 import inspect
-import os
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -718,7 +717,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
                 raise ValueError("A single batch of multiple conditionings are supported at the moment.")
             elif len(image) != len(self.controlnet.nets):
                 raise ValueError(
-                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                 )
 
             for image_ in image:

@@ -957,18 +956,6 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
 
         return image_latents
 
-    # override DiffusionPipeline
-    def save_pretrained(
-        self,
-        save_directory: Union[str, os.PathLike],
-        safe_serialization: bool = False,
-        variant: Optional[str] = None,
-    ):
-        if isinstance(self.controlnet, ControlNetModel):
-            super().save_pretrained(save_directory, safe_serialization, variant)
-        else:
-            raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -346,21 +346,6 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
             except NotImplementedError:
                 pass
 
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_float16(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_local(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_optional_components(self):
-        ...
-
 
 @slow
 @require_torch_gpu

@@ -304,21 +304,6 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
             except NotImplementedError:
                 pass
 
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_float16(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_local(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_optional_components(self):
-        ...
-
 
 @slow
 @require_torch_gpu

@@ -382,21 +382,6 @@ class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
             except NotImplementedError:
                 pass
 
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_float16(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_local(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_optional_components(self):
-        ...
-
 
 @slow
 @require_torch_gpu

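The three skips are dropped because the standard PipelineTesterMixin save/load tests now apply to the multi-ControlNet pipelines. Roughly, those checks boil down to a round trip of this shape (a simplified sketch, not the actual test code; `pipe` and `get_inputs` are assumed to come from the test's dummy components, with `get_inputs()` returning freshly seeded inputs on each call):

    import tempfile

    import numpy as np

    def check_save_load_roundtrip(pipe, get_inputs):
        image = pipe(**get_inputs()).images[0]
        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            pipe_loaded = pipe.__class__.from_pretrained(tmpdir)
        image_loaded = pipe_loaded(**get_inputs()).images[0]
        # identical weights and a fixed seed should give (near-)identical outputs
        assert np.abs(np.asarray(image) - np.asarray(image_loaded)).max() < 1e-3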