1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[WIP] Refactor UniDiffuser Pipeline and Tests (#4948)

* Add VAE slicing and tiling methods.

* Switch to using VaeImageProcessing for preprocessing and postprocessing of images.

* Rename the VaeImageProcessor to vae_image_processor to avoid a name clash with the CLIPImageProcessor (image_processor).

* Remove the postprocess() function because we're using a VaeImageProcessor instead.

* Remove UniDiffuserPipeline.decode_image_latents because we're using VaeImageProcessor instead.

* Refactor generating text from text latents into a decode_text_latents method.

* Add enable_full_determinism() to UniDiffuser tests.

* make style

* Add PipelineLatentTesterMixin to UniDiffuserPipelineFastTests.

* Remove enable_model_cpu_offload since it is now part of DiffusionPipeline.

* Rename the VaeImageProcessor instance to self.image_processor for consistency with other pipelines and rename the CLIPImageProcessor instance to clip_image_processor to avoid a name clash.

* Update UniDiffuser conversion script.

* Make safe_serialization configurable in UniDiffuser conversion script.

* Rename image_processor to clip_image_processor in UniDiffuser tests.

* Add PipelineKarrasSchedulerTesterMixin to UniDiffuserPipelineFastTests.

* Add initial test for compiling the UniDiffuser model (not tested yet).

* Update encode_prompt and _encode_prompt to match that of StableDiffusionPipeline.

* Turn off standard classifier-free guidance for now.

* make style

* make fix-copies

* apply suggestions from review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
dg845
2023-10-02 09:24:55 -07:00
committed by GitHub
parent db91e710da
commit cd1b8d7ca8
3 changed files with 270 additions and 153 deletions

View File

@@ -73,17 +73,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
@@ -92,6 +92,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
return mapping
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
# config.num_head_channels => num_head_channels
def assign_to_checkpoint(
@@ -104,8 +117,9 @@ def assign_to_checkpoint(
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new
checkpoint.
attention layers, and takes into account additional replacements that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
@@ -143,25 +157,16 @@ def assign_to_checkpoint(
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
shape = old_checkpoint[path["old"]].shape
if is_attn_weight and len(shape) == 3:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
elif is_attn_weight and len(shape) == 4:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def create_vae_diffusers_config(config_type):
# Hardcoded for now
if args.config_type == "test":
@@ -339,7 +344,7 @@ def create_text_decoder_config_big():
return text_decoder_config
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
"""
Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.
@@ -674,6 +679,11 @@ if __name__ == "__main__":
type=int,
help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
)
parser.add_argument(
"--safe_serialization",
action="store_true",
help="Whether to use safetensors/safe seialization when saving the pipeline.",
)
args = parser.parse_args()
@@ -766,11 +776,11 @@ if __name__ == "__main__":
vae=vae,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_processor=image_processor,
clip_image_processor=image_processor,
clip_tokenizer=clip_tokenizer,
text_decoder=text_decoder,
text_tokenizer=text_tokenizer,
unet=unet,
scheduler=scheduler,
)
pipeline.save_pretrained(args.pipeline_output_path)
pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)