[WIP] Refactor UniDiffuser Pipeline and Tests (#4948)

* Add VAE slicing and tiling methods.
* Switch to using VaeImageProcessor for preprocessing and postprocessing of images.
* Rename the VaeImageProcessor to vae_image_processor to avoid a name clash with the CLIPImageProcessor (image_processor).
* Remove the postprocess() function because we're using a VaeImageProcessor instead.
* Remove UniDiffuserPipeline.decode_image_latents because we're using a VaeImageProcessor instead.
* Refactor generating text from text latents into a decode_text_latents method.
* Add enable_full_determinism() to UniDiffuser tests.
* make style
* Add PipelineLatentTesterMixin to UniDiffuserPipelineFastTests.
* Remove enable_model_cpu_offload since it is now part of DiffusionPipeline.
* Rename the VaeImageProcessor instance to self.image_processor for consistency with other pipelines, and rename the CLIPImageProcessor instance to clip_image_processor to avoid a name clash.
* Update UniDiffuser conversion script.
* Make safe_serialization configurable in UniDiffuser conversion script.
* Rename image_processor to clip_image_processor in UniDiffuser tests.
* Add PipelineKarrasSchedulerTesterMixin to UniDiffuserPipelineFastTests.
* Add initial test for compiling the UniDiffuser model (not tested yet).
* Update encode_prompt and _encode_prompt to match that of StableDiffusionPipeline.
* Turn off standard classifier-free guidance for now.
* make style
* make fix-copies
* apply suggestions from review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
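Most of these changes surface directly in the pipeline's public API. Below is a minimal usage sketch of the refactored pipeline; the checkpoint id "thu-ml/unidiffuser-v1", the prompt, and the CUDA device are illustrative assumptions, not part of this PR:

import torch
from diffusers import UniDiffuserPipeline

# Checkpoint id assumed for illustration.
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# The VAE slicing/tiling methods added in this PR trade speed for lower peak memory.
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

# enable_model_cpu_offload() is no longer defined on UniDiffuserPipeline itself;
# the inherited DiffusionPipeline implementation is used instead.

# Text-to-image sampling; the output exposes .images like other diffusers pipelines.
image = pipe(prompt="an astronaut riding a horse", num_inference_steps=20).images[0]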
@@ -73,17 +73,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("norm.weight", "group_norm.weight")
         new_item = new_item.replace("norm.bias", "group_norm.bias")

-        new_item = new_item.replace("q.weight", "query.weight")
-        new_item = new_item.replace("q.bias", "query.bias")
+        new_item = new_item.replace("q.weight", "to_q.weight")
+        new_item = new_item.replace("q.bias", "to_q.bias")

-        new_item = new_item.replace("k.weight", "key.weight")
-        new_item = new_item.replace("k.bias", "key.bias")
+        new_item = new_item.replace("k.weight", "to_k.weight")
+        new_item = new_item.replace("k.bias", "to_k.bias")

-        new_item = new_item.replace("v.weight", "value.weight")
-        new_item = new_item.replace("v.bias", "value.bias")
+        new_item = new_item.replace("v.weight", "to_v.weight")
+        new_item = new_item.replace("v.bias", "to_v.bias")

-        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
-        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
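The hunk above retargets the VAE attention renaming from diffusers' legacy AttentionBlock parameter names (query/key/value/proj_attn) to the current Attention module names (to_q/to_k/to_v/to_out.0). As a standalone sketch, the mapping acts like this (the checkpoint keys are illustrative, not taken from a real checkpoint):

# Illustrative LDM-style checkpoint keys.
old_keys = ["encoder.mid.attn_1.q.weight", "encoder.mid.attn_1.proj_out.bias"]
new_keys = []
for key in old_keys:
    key = key.replace("q.weight", "to_q.weight")
    key = key.replace("proj_out.bias", "to_out.0.bias")
    new_keys.append(key)
print(new_keys)  # ['encoder.mid.attn_1.to_q.weight', 'encoder.mid.attn_1.to_out.0.bias']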
@@ -92,6 +92,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
     return mapping


+# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
+def conv_attn_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    attn_keys = ["query.weight", "key.weight", "value.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in attn_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+        elif "proj_attn.weight" in key:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0]
+
+
 # Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
 # config.num_head_channels => num_head_channels
 def assign_to_checkpoint(
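conv_attn_to_linear handles checkpoints whose attention projections were stored as 1x1 convolutions: it slices away the trailing spatial dimensions so the tensors fit nn.Linear layers. A self-contained sketch of the transformation, with an assumed channel count:

import torch

conv_weight = torch.randn(512, 512, 1, 1)  # 1x1 Conv2d attention weight (channel count assumed)
linear_weight = conv_weight[:, :, 0, 0]    # drop both spatial dims -> shape (512, 512)
assert linear_weight.shape == (512, 512)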
@@ -104,8 +117,9 @@ def assign_to_checkpoint(
 ):
     """
     This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
-    attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new
-    checkpoint.
+    attention layers, and takes into account additional replacements that may arise.
+
+    Assigns the weights to the new checkpoint.
     """
     assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
@@ -143,25 +157,16 @@ def assign_to_checkpoint(
                 new_path = new_path.replace(replacement["old"], replacement["new"])

         # proj_attn.weight has to be converted from conv 1D to linear
-        if "proj_attn.weight" in new_path:
+        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+        shape = old_checkpoint[path["old"]].shape
+        if is_attn_weight and len(shape) == 3:
             checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        elif is_attn_weight and len(shape) == 4:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
         else:
             checkpoint[new_path] = old_checkpoint[path["old"]]


-# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
-def conv_attn_to_linear(checkpoint):
-    keys = list(checkpoint.keys())
-    attn_keys = ["query.weight", "key.weight", "value.weight"]
-    for key in keys:
-        if ".".join(key.split(".")[-2:]) in attn_keys:
-            if checkpoint[key].ndim > 2:
-                checkpoint[key] = checkpoint[key][:, :, 0, 0]
-        elif "proj_attn.weight" in key:
-            if checkpoint[key].ndim > 2:
-                checkpoint[key] = checkpoint[key][:, :, 0]
-
-
 def create_vae_diffusers_config(config_type):
     # Hardcoded for now
     if args.config_type == "test":
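The rewritten branch generalizes the old proj_attn-only special case: is_attn_weight also matches the new to_q/to_k/to_v/to_out names, and the rank of the weight decides how many trailing dimensions to slice off (3-D for Conv1d-style weights, 4-D for Conv2d-style ones). A quick standalone check of the predicate on representative key names (the paths are illustrative):

for new_path in [
    "encoder.mid_block.attentions.0.to_q.weight",       # matches "attentions" and "to_"
    "encoder.mid_block.attentions.0.proj_attn.weight",  # matches "proj_attn.weight"
    "encoder.conv_in.weight",                           # matches neither
]:
    is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
    print(new_path, is_attn_weight)  # True, True, False respectively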
@@ -339,7 +344,7 @@ def create_text_decoder_config_big():
     return text_decoder_config


-# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint
+# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
 def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
     """
     Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.
@@ -674,6 +679,11 @@ if __name__ == "__main__":
         type=int,
         help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
     )
+    parser.add_argument(
+        "--safe_serialization",
+        action="store_true",
+        help="Whether to use safetensors/safe serialization when saving the pipeline.",
+    )

     args = parser.parse_args()
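Because the new argument uses action="store_true", args.safe_serialization defaults to False and becomes True only when the flag is passed, so existing invocations of the conversion script keep their old behavior. A standalone demonstration of the same argparse semantics:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--safe_serialization", action="store_true")

print(parser.parse_args([]).safe_serialization)                         # False (default)
print(parser.parse_args(["--safe_serialization"]).safe_serialization)   # True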
@@ -766,11 +776,11 @@ if __name__ == "__main__":
         vae=vae,
         text_encoder=text_encoder,
         image_encoder=image_encoder,
-        image_processor=image_processor,
+        clip_image_processor=image_processor,
         clip_tokenizer=clip_tokenizer,
         text_decoder=text_decoder,
         text_tokenizer=text_tokenizer,
         unet=unet,
         scheduler=scheduler,
     )
-    pipeline.save_pretrained(args.pipeline_output_path)
+    pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)
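With safe_serialization=True, save_pretrained writes the weights as *.safetensors files instead of pickle-based *.bin files; from_pretrained detects either format. A sketch of loading the converted pipeline back, where the local path is a placeholder standing in for whatever was passed as --pipeline_output_path:

from diffusers import UniDiffuserPipeline

# Placeholder path; use the directory given to the conversion script.
pipe = UniDiffuserPipeline.from_pretrained("./unidiffuser_pipeline")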