mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
* Quick implementation of t2i-adapter Load adapter module with from_pretrained Prototyping generalized adapter framework Writeup doc string for sideload framework(WIP) + some minor update on implementation Update adapter models Remove old adapter optional args in UNet Add StableDiffusionAdapterPipeline unit test Handle cpu offload in StableDiffusionAdapterPipeline Auto correct coding style Update model repo name to "RzZ/sd-v1-4-adapter-pipeline" Refactor MultiAdapter to better compatible with config system Export MultiAdapter Create pipeline document template from controlnet Create dummy objects Supproting new AdapterLight model Fix StableDiffusionAdapterPipeline common pipeline test [WIP] Update adapter pipeline document Handle num_inference_steps in StableDiffusionAdapterPipeline Update definition of Adapter "channels_in" Update documents Apply code style Fix doc typo and merge error Update doc string and example Quality of life improvement Remove redundant code and file from prototyping Remove unused pageage Remove comments Fix title Fix typo Add conditioning scale arg Bring back old implmentation Offload sideload Add supply info on document Update src/diffusers/models/adapter.py Co-authored-by: Will Berman <wlbberman@gmail.com> Update MultiAdapter constructor Swap out custom checkpoint and update pipeline constructor Update docment Apply suggestions from code review Co-authored-by: Will Berman <wlbberman@gmail.com> Correcting style Following single-file policy Update auto size in image preprocess func Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py Co-authored-by: Will Berman <wlbberman@gmail.com> fix copies Update adapter pipeline behavior Add adapter_conditioning_scale doc string Add the missing doc string Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Fix few bugs from suggestion Handle L-mode PIL image as control image Rename to differentiate adapter resblock Update 
src/diffusers/models/adapter.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Fix typo Update adapter parameter name Update test case and code style Fix copies Fix typo Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py Co-authored-by: Will Berman <wlbberman@gmail.com> Update Adapter class name Add checkpoint converting script Fix style Fix-copies Remove dev script Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Updates for parameter rename Fix convert_adapter remove main fix diff more refactoring more more small fixes refactor tests more slow tests more tests Update docs/source/en/api/pipelines/overview.mdx Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> add community contributor to docs Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> fix remove from_adapters license paper link docs more url fixes more docs fix fixes fix fix * fix sample inplace add * additional_kwargs -> additional_residuals * move t2i adapter pipeline to own module * preprocess -> _preprocess_adapter_image * add TencentArc to license * fix example code links * add image converter and fix example doc string * fix links * clearer additional residual application --------- Co-authored-by: HimariO <dsfhe49854@gmail.com>
251 lines
14 KiB
Python
251 lines
14 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Conversion script for the T2I-Adapter checkpoints.
|
|
"""
|
|
|
|
import argparse
|
|
|
|
import torch
|
|
|
|
from diffusers import T2IAdapter
|
|
|
|
|
|
def convert_adapter(src_state, in_channels):
    """Convert an original T2I full-adapter checkpoint into a diffusers ``T2IAdapter``.

    Original block indices map pairwise onto diffusers blocks: source blocks
    ``2*i`` and ``2*i + 1`` become ``adapter.body.i.resnets.0`` and
    ``adapter.body.i.resnets.1``; levels 1 and 2 additionally carry an
    ``in_conv`` on their first source block.

    Args:
        src_state: State dict loaded from the original TencentARC checkpoint.
            Entries are consumed (popped) during conversion.
        in_channels: Number of input image channels for the adapter.

    Returns:
        A ``T2IAdapter`` ("full_adapter" type) with the converted weights loaded.
    """
    original_body_length = max(int(x.split(".")[1]) for x in src_state.keys() if "body." in x) + 1

    # The full adapter always has 8 original residual blocks: two per channel level.
    assert original_body_length == 8

    # Sanity-check the expected weight shapes at each channel level.
    # (0, 1) -> channels 1
    assert src_state["body.0.block1.weight"].shape == (320, 320, 3, 3)
    # (2, 3) -> channels 2
    assert src_state["body.2.in_conv.weight"].shape == (640, 320, 1, 1)
    # (4, 5) -> channels 3
    assert src_state["body.4.in_conv.weight"].shape == (1280, 640, 1, 1)
    # (6, 7) -> channels 4
    assert src_state["body.6.block1.weight"].shape == (1280, 1280, 3, 3)

    res_state = {
        "adapter.conv_in.weight": src_state.pop("conv_in.weight"),
        "adapter.conv_in.bias": src_state.pop("conv_in.bias"),
    }

    for level in range(4):
        for resnet_idx in range(2):
            src_block = 2 * level + resnet_idx
            # Only downsampling levels 1 and 2 have a channel-changing in_conv,
            # attached to the first source block of the pair (src blocks 2 and 4).
            if resnet_idx == 0 and src_block in (2, 4):
                for param in ("weight", "bias"):
                    res_state[f"adapter.body.{level}.in_conv.{param}"] = src_state.pop(
                        f"body.{src_block}.in_conv.{param}"
                    )
            for conv in ("block1", "block2"):
                for param in ("weight", "bias"):
                    res_state[f"adapter.body.{level}.resnets.{resnet_idx}.{conv}.{param}"] = src_state.pop(
                        f"body.{src_block}.{conv}.{param}"
                    )

    # Every source entry must have been consumed, otherwise the mapping is incomplete.
    assert len(src_state) == 0

    adapter = T2IAdapter(in_channels=in_channels, adapter_type="full_adapter")
    adapter.load_state_dict(res_state)
    return adapter
|
|
|
|
|
|
def convert_light_adapter(src_state):
    """Convert an original T2I light-adapter checkpoint into a diffusers ``T2IAdapter``.

    Each of the 4 original body blocks is ``in_conv`` -> 4 residual sub-blocks
    (``body.i.body.r.block{1,2}``) -> ``out_conv``; sub-block ``r`` maps to
    ``adapter.body.i.resnets.r`` in the diffusers layout.

    Args:
        src_state: State dict loaded from the original TencentARC checkpoint.
            Entries are consumed (popped) during conversion.

    Returns:
        A ``T2IAdapter`` ("light_adapter" type) with the converted weights loaded.
    """
    original_body_length = max(int(x.split(".")[1]) for x in src_state.keys() if "body." in x) + 1

    # The light adapter always has exactly 4 body blocks.
    assert original_body_length == 4

    res_state = {}
    for level in range(4):
        for param in ("weight", "bias"):
            res_state[f"adapter.body.{level}.in_conv.{param}"] = src_state.pop(f"body.{level}.in_conv.{param}")
        for resnet_idx in range(4):
            for conv in ("block1", "block2"):
                for param in ("weight", "bias"):
                    res_state[f"adapter.body.{level}.resnets.{resnet_idx}.{conv}.{param}"] = src_state.pop(
                        f"body.{level}.body.{resnet_idx}.{conv}.{param}"
                    )
        for param in ("weight", "bias"):
            res_state[f"adapter.body.{level}.out_conv.{param}"] = src_state.pop(f"body.{level}.out_conv.{param}")

    # Every source entry must have been consumed, otherwise the mapping is incomplete.
    assert len(src_state) == 0

    adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280], num_res_blocks=4, adapter_type="light_adapter")
    adapter.load_state_dict(res_state)
    return adapter
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--output_path", default=None, type=str, required=True, help="Path to the store the result checkpoint."
    )
    parser.add_argument(
        "--is_adapter_light",
        action="store_true",
        help="Is checkpoint come from Adapter-Light architecture. ex: color-adapter",
    )
    parser.add_argument("--in_channels", required=False, type=int, help="Input channels for non-light adapter")

    args = parser.parse_args()

    # NOTE(review): torch.load unpickles arbitrary code — only run this on trusted checkpoints.
    # map_location="cpu" lets conversion run on CPU-only machines even when the
    # checkpoint was saved from GPU tensors.
    src_state = torch.load(args.checkpoint_path, map_location="cpu")

    if args.is_adapter_light:
        adapter = convert_light_adapter(src_state)
    else:
        # The full adapter's input channel count is not recoverable from the
        # checkpoint itself, so it must be supplied explicitly.
        if args.in_channels is None:
            raise ValueError("set `--in_channels=<n>`")
        adapter = convert_adapter(src_state, args.in_channels)

    adapter.save_pretrained(args.output_path)