1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Files
diffusers/scripts/convert_original_t2i_adapter.py
Will Berman a0597f33ac t2i pipeline (#3932)
* Quick implementation of t2i-adapter

Load adapter module with from_pretrained

Prototyping generalized adapter framework

Write up docstring for sideload framework (WIP) + some minor updates to the implementation

Update adapter models

Remove old adapter optional args in UNet

Add StableDiffusionAdapterPipeline unit test

Handle cpu offload in StableDiffusionAdapterPipeline

Auto correct coding style

Update model repo name to "RzZ/sd-v1-4-adapter-pipeline"

Refactor MultiAdapter to be more compatible with the config system

Export MultiAdapter

Create pipeline document template from controlnet

Create dummy objects

Supporting new AdapterLight model

Fix StableDiffusionAdapterPipeline common pipeline test

[WIP] Update adapter pipeline document

Handle num_inference_steps in StableDiffusionAdapterPipeline

Update definition of Adapter "channels_in"

Update documents

Apply code style

Fix doc typo and merge error

Update doc string and example

Quality of life improvement

Remove redundant code and file from prototyping

Remove unused package

Remove comments

Fix title

Fix typo

Add conditioning scale arg

Bring back old implementation

Offload sideload

Add supply info on document

Update src/diffusers/models/adapter.py

Co-authored-by: Will Berman <wlbberman@gmail.com>

Update MultiAdapter constructor

Swap out custom checkpoint and update pipeline constructor

Update document

Apply suggestions from code review

Co-authored-by: Will Berman <wlbberman@gmail.com>

Correcting style

Following single-file policy

Update auto size in image preprocess func

Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py

Co-authored-by: Will Berman <wlbberman@gmail.com>

fix copies

Update adapter pipeline behavior

Add adapter_conditioning_scale doc string

Add the missing doc string

Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Fix few bugs from suggestion

Handle L-mode PIL image as control image

Rename to differentiate adapter resblock

Update src/diffusers/models/adapter.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Fix typo

Update adapter parameter name

Update test case and code style

Fix copies

Fix typo

Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py

Co-authored-by: Will Berman <wlbberman@gmail.com>

Update Adapter class name

Add checkpoint converting script

Fix style

Fix-copies

Remove dev script

Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Updates for parameter rename

Fix convert_adapter

remove main

fix diff

more

refactoring

more

more

small fixes

refactor

tests

more slow tests

more tests

Update docs/source/en/api/pipelines/overview.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

add community contributor to docs

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

fix

remove from_adapters

license

paper link

docs

more url fixes

more docs

fix

fixes

fix

fix

* fix sample inplace add

* additional_kwargs -> additional_residuals

* move t2i adapter pipeline to own module

* preprocess -> _preprocess_adapter_image

* add TencentArc to license

* fix example code links

* add image converter and fix example doc string

* fix links

* clearer additional residual application

---------

Co-authored-by: HimariO <dsfhe49854@gmail.com>
2023-07-17 12:55:44 -07:00

251 lines
14 KiB
Python

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Conversion script for the T2I-Adapter checkpoints.
"""
import argparse
import torch
from diffusers import T2IAdapter
def convert_adapter(src_state, in_channels):
    """Convert an original (full) T2I-Adapter checkpoint into a diffusers ``T2IAdapter``.

    The original full adapter stores a top-level ``conv_in`` plus eight residual blocks
    ``body.0`` ... ``body.7``, two per resolution stage. diffusers groups them into four
    stages: original block ``body.<k>`` becomes ``adapter.body.<k // 2>.resnets.<k % 2>``.
    The first block of stages 1 and 2 also owns an ``in_conv`` (a 1x1 conv that widens
    the channel count 320 -> 640 -> 1280), which moves to ``adapter.body.<stage>.in_conv``.

    Args:
        src_state: State dict loaded from the original checkpoint. It is NOT modified;
            a shallow copy is consumed instead.
        in_channels: Number of input channels of the adapter, forwarded to ``T2IAdapter``.

    Returns:
        A ``T2IAdapter`` (``adapter_type="full_adapter"``) with the converted weights loaded.

    Raises:
        ValueError: If the checkpoint does not have the expected layout (number of
            ``body.*`` blocks, channel widths) or contains unexpected extra parameters.
        KeyError: If parameters the full adapter requires are missing.
    """
    num_blocks = 8
    params = ("weight", "bias")

    # Work on a shallow copy so the caller's state dict survives (even on failure).
    src_state = dict(src_state)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    body_indices = [int(key.split(".")[1]) for key in src_state if key.startswith("body.")]
    original_body_length = max(body_indices, default=-1) + 1
    if original_body_length != num_blocks:
        raise ValueError(
            f"Expected {num_blocks} `body.*` blocks in a full adapter checkpoint, found {original_body_length}."
        )

    # Sanity-check the channel width of each stage: (body.0, body.1) -> 320,
    # (body.2, body.3) -> 640, (body.4, body.5) -> 1280, (body.6, body.7) -> 1280.
    expected_shapes = {
        "body.0.block1.weight": (320, 320, 3, 3),
        "body.2.in_conv.weight": (640, 320, 1, 1),
        "body.4.in_conv.weight": (1280, 640, 1, 1),
        "body.6.block1.weight": (1280, 1280, 3, 3),
    }
    for key, expected in expected_shapes.items():
        actual = tuple(src_state[key].shape)
        if actual != expected:
            raise ValueError(f"Unexpected shape for `{key}`: expected {expected}, got {actual}.")

    # Build the original-name -> diffusers-name mapping.
    key_map = {f"conv_in.{param}": f"adapter.conv_in.{param}" for param in params}
    for src_idx in range(num_blocks):
        stage, res_idx = divmod(src_idx, 2)
        # Only stages 1 and 2 change the channel width, so only their first block has an in_conv.
        if res_idx == 0 and stage in (1, 2):
            for param in params:
                key_map[f"body.{src_idx}.in_conv.{param}"] = f"adapter.body.{stage}.in_conv.{param}"
        for block in ("block1", "block2"):
            for param in params:
                key_map[f"body.{src_idx}.{block}.{param}"] = f"adapter.body.{stage}.resnets.{res_idx}.{block}.{param}"

    missing = [key for key in key_map if key not in src_state]
    if missing:
        raise KeyError(f"Checkpoint is missing expected parameters: {missing}")

    res_state = {dst: src_state.pop(src) for src, dst in key_map.items()}

    if src_state:
        raise ValueError(f"Checkpoint contains unexpected parameters: {sorted(src_state)}")

    adapter = T2IAdapter(in_channels=in_channels, adapter_type="full_adapter")
    # load_state_dict is strict by default, so any remaining key/shape mismatch raises here.
    adapter.load_state_dict(res_state)
    return adapter
def convert_light_adapter(src_state, in_channels=3):
    """Convert an original Adapter-Light T2I-Adapter checkpoint into a diffusers ``T2IAdapter``.

    The original light adapter has four stages ``body.0`` ... ``body.3``. Each stage holds
    an ``in_conv``, four residual blocks under ``body.<stage>.body.<0..3>`` and an ``out_conv``.
    diffusers keeps the same stage layout under ``adapter.body.<stage>``, renaming the inner
    ``body.<j>`` container to ``resnets.<j>``.

    Args:
        src_state: State dict loaded from the original checkpoint. It is NOT modified;
            a shallow copy is consumed instead.
        in_channels: Number of input channels of the adapter, forwarded to ``T2IAdapter``.
            Defaults to 3, the value this converter has always used.

    Returns:
        A ``T2IAdapter`` (``adapter_type="light_adapter"``) with the converted weights loaded.

    Raises:
        ValueError: If the checkpoint does not have four ``body.*`` stages or contains
            unexpected extra parameters.
        KeyError: If parameters the light adapter requires are missing.
    """
    num_stages = 4
    num_res_blocks = 4
    params = ("weight", "bias")

    # Work on a shallow copy so the caller's state dict survives (even on failure).
    src_state = dict(src_state)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    body_indices = [int(key.split(".")[1]) for key in src_state if key.startswith("body.")]
    original_body_length = max(body_indices, default=-1) + 1
    if original_body_length != num_stages:
        raise ValueError(
            f"Expected {num_stages} `body.*` stages in a light adapter checkpoint, found {original_body_length}."
        )

    # Build the original-name -> diffusers-name mapping; the layout is identical per stage.
    key_map = {}
    for stage in range(num_stages):
        src_prefix = f"body.{stage}"
        dst_prefix = f"adapter.body.{stage}"
        for param in params:
            key_map[f"{src_prefix}.in_conv.{param}"] = f"{dst_prefix}.in_conv.{param}"
        for res_idx in range(num_res_blocks):
            for block in ("block1", "block2"):
                for param in params:
                    key_map[f"{src_prefix}.body.{res_idx}.{block}.{param}"] = (
                        f"{dst_prefix}.resnets.{res_idx}.{block}.{param}"
                    )
        for param in params:
            key_map[f"{src_prefix}.out_conv.{param}"] = f"{dst_prefix}.out_conv.{param}"

    missing = [key for key in key_map if key not in src_state]
    if missing:
        raise KeyError(f"Checkpoint is missing expected parameters: {missing}")

    res_state = {dst: src_state.pop(src) for src, dst in key_map.items()}

    if src_state:
        raise ValueError(f"Checkpoint contains unexpected parameters: {sorted(src_state)}")

    adapter = T2IAdapter(
        in_channels=in_channels,
        channels=[320, 640, 1280],
        num_res_blocks=num_res_blocks,
        adapter_type="light_adapter",
    )
    # load_state_dict is strict by default, so any remaining key/shape mismatch raises here.
    adapter.load_state_dict(res_state)
    return adapter
def main():
    """Parse CLI arguments, convert an original T2I-Adapter checkpoint and save it in diffusers format."""
    parser = argparse.ArgumentParser(description="Convert an original T2I-Adapter checkpoint to the diffusers format.")
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--output_path", default=None, type=str, required=True, help="Path to store the converted checkpoint."
    )
    parser.add_argument(
        "--is_adapter_light",
        action="store_true",
        help="Whether the checkpoint comes from the Adapter-Light architecture, e.g. color-adapter.",
    )
    parser.add_argument(
        "--in_channels",
        required=False,
        type=int,
        help="Input channels for the non-light adapter (required unless --is_adapter_light is set).",
    )
    args = parser.parse_args()

    # Validate arguments before the potentially slow checkpoint load; parser.error prints
    # usage and exits with status 2 instead of dumping a traceback.
    if not args.is_adapter_light and args.in_channels is None:
        parser.error("set `--in_channels=<n>`")

    # NOTE: torch.load unpickles the file, which can execute arbitrary code -- only convert
    # checkpoints from trusted sources. map_location="cpu" lets checkpoints saved with CUDA
    # tensors load on machines without a GPU.
    src_state = torch.load(args.checkpoint_path, map_location="cpu")

    if args.is_adapter_light:
        adapter = convert_light_adapter(src_state)
    else:
        adapter = convert_adapter(src_state, args.in_channels)

    adapter.save_pretrained(args.output_path)


if __name__ == "__main__":
    main()