Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Würstchen model (#3849)
* initial
* initial
* added initial convert script for paella vqmodel
* initial wuerstchen pipeline
* add LayerNorm2d
* added modules
* fix typo
* use model_v2
* embed clip caption amd negative_caption
* fixed name of var
* initial modules in one place
* WuerstchenPriorPipeline
* inital shape
* initial denoising prior loop
* fix output
* add WuerstchenPriorPipeline to __init__.py
* use the noise ratio in the Prior
* try to save pipeline
* save_pretrained working
* Few additions
* add _execution_device
* shape is int
* fix batch size
* fix shape of ratio
* fix shape of ratio
* fix output dataclass
* tests folder
* fix formatting
* fix float16 + started with generator
* Update pipeline_wuerstchen.py
* removed vqgan code
* add WuerstchenGeneratorPipeline
* fix WuerstchenGeneratorPipeline
* fix docstrings
* fix imports
* convert generator pipeline
* fix convert
* Work on Generator Pipeline. WIP
* Pipeline works with our diffuzz code
* apply scale factor
* removed vqgan.py
* use cosine schedule
* redo the denoising loop
* Update src/diffusers/models/resnet.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* use torch.lerp
* use warp-diffusion org
* clip_sample=False,
* some refactoring
* use model_v3_stage_c
* c_cond size
* use clip-bigG
* allow stage b clip to be None
* add dummy
* würstchen scheduler
* minor changes
* set clip=None in the pipeline
* fix attention mask
* add attention_masks to text_encoder
* make fix-copies
* add back clip
* add text_encoder
* gen_text_encoder and tokenizer
* fix import
* updated pipeline test
* undo changes to pipeline test
* nip
* fix typo
* fix output name
* set guidance_scale=0 and remove diffuze
* fix doc strings
* make style
* nip
* removed unused
* initial docs
* rename
* toc
* cleanup
* remvoe test script
* fix-copies
* fix multi images
* remove dup
* remove unused modules
* undo changes for debugging
* no new line
* remove dup conversion script
* fix doc string
* cleanup
* pass default args
* dup permute
* fix some tests
* fix prepare_latents
* move Prior class to modules
* offload only the text encoder and vqgan
* fix resolution calculation for prior
* nip
* removed testing script
* fix shape
* fix argument to set_timesteps
* do not change .gitignore
* fix resolution calculations + readme
* resolution calculation fix + readme
* small fixes
* Add combined pipeline
* rename generator -> decoder
* Update .gitignore
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* removed efficient_net
* create combined WuerstchenPipeline
* make arguments consistent with VQ model
* fix var names
* no need to return text_encoder_hidden_states
* add latent_dim_scale to config
* split model into its own file
* add WuerschenPipeline to docs
* remove unused latent_size
* register latent_dim_scale
* update script
* update docstring
* use Attention preprocessor
* concat with normed input
* fix-copies
* add docs
* fix test
* fix style
* add to cpu_offloaded_model
* updated type
* remove 1-line func
* updated type
* initial decoder test
* formatting
* formatting
* fix autodoc link
* num_inference_steps is int
* remove comments
* fix example in docs
* Update src/diffusers/pipelines/wuerstchen/diffnext.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* rename layernorm to WuerstchenLayerNorm
* rename DiffNext to WuerstchenDiffNeXt
* added comment about MixingResidualBlock
* move paella vq-vae to pipelines' folder
* initial decoder test
* increased test_float16_inference expected diff
* self_attn is always true
* more passing decoder tests
* batch image_embeds
* fix failing tests
* set the correct dtype
* relax inference test
* update prior
* added combined pipeline test
* faster test
* faster test
* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* fix issues from review
* update wuerstchen.md + change generator name
* resolve issues
* fix copied from usage and add back batch_size
* fix API
* fix arguments
* fix combined test
* Added timesteps argument + fixes
* Update tests/pipelines/test_pipelines_common.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update tests/pipelines/wuerstchen/test_wuerstchen_prior.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
* up
* Fix more
* failing tests
* up
* up
* correct naming
* correct docs
* correct docs
* fix test params
* correct docs
* fix classifier free guidance
* fix classifier free guidance
* fix more
* fix all
* make tests faster

---------

Co-authored-by: Dominic Rampas <d6582533@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dominic Rampas <61938694+dome272@users.noreply.github.com>
This commit is contained in:

scripts/convert_wuerstchen.py (new file, 115 lines)
@@ -0,0 +1,115 @@
# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/
import os

import torch
from transformers import AutoTokenizer, CLIPTextModel
from vqgan import VQModel

from diffusers import (
    DDPMWuerstchenScheduler,
    WuerstchenCombinedPipeline,
    WuerstchenDecoderPipeline,
    WuerstchenPriorPipeline,
)
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior

model_path = "models/"
device = "cpu"

# Paella VQGAN: load the original checkpoint, then remap the quantizer codebook
# key to the `embedding` name expected by diffusers' PaellaVQModel.
paella_vqmodel = VQModel()
state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
paella_vqmodel.load_state_dict(state_dict)

state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
state_dict.pop("vquantizer.codebook.weight")
vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
vqmodel.load_state_dict(state_dict)
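
# Optional sanity check, a minimal sketch and an assumption rather than part of the
# original script: round-trip a dummy image through the converted VQ model and make
# sure shapes line up (assuming PaellaVQModel follows diffusers' VQModel encode/decode API).
with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)
    latents = vqmodel.encode(dummy).latents
    reconstruction = vqmodel.decode(latents).sample
    assert reconstruction.shape == dummy.shape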

# CLIP text encoder and tokenizer for the prior (stage C)
text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# CLIP text encoder and tokenizer for the decoder (stage B)
gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu")
gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

# Decoder (stage B): remap the fused attention projections of the original
# checkpoint into the split to_q/to_k/to_v layout used by diffusers' Attention.
orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"]
state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        # Fused QKV weight: split the first dimension into query/key/value thirds.
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = orig_state_dict[key]
    elif key.endswith("out_proj.bias"):
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = orig_state_dict[key]
    else:
        state_dict[key] = orig_state_dict[key]
decoder = WuerstchenDiffNeXt()
decoder.load_state_dict(state_dict)

# Prior (stage C): load the EMA weights and apply the same attention-key remapping.
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = orig_state_dict[key]
    elif key.endswith("out_proj.bias"):
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = orig_state_dict[key]
    else:
        state_dict[key] = orig_state_dict[key]
prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device)
prior_model.load_state_dict(state_dict)
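
# The two remapping loops above are identical. A hypothetical helper like the one
# below could replace both, e.g. `state_dict = convert_attention_keys(orig_state_dict)`.
# This is a sketch only; it is not part of the original conversion script.
def convert_attention_keys(orig_state_dict):
    converted = {}
    for key, value in orig_state_dict.items():
        if key.endswith("in_proj_weight") or key.endswith("in_proj_bias"):
            # Fused QKV projection: split into query/key/value thirds along dim 0.
            suffix = "weight" if key.endswith("weight") else "bias"
            q, k, v = value.chunk(3, 0)
            for name, part in zip(("to_q", "to_k", "to_v"), (q, k, v)):
                converted[key.replace(f"attn.in_proj_{suffix}", f"{name}.{suffix}")] = part
        elif key.endswith("out_proj.weight") or key.endswith("out_proj.bias"):
            suffix = key.rsplit(".", 1)[-1]
            converted[key.replace(f"attn.out_proj.{suffix}", f"to_out.0.{suffix}")] = value
        else:
            converted[key] = value
    return converted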

# Scheduler (shared by the prior and decoder stages)
scheduler = DDPMWuerstchenScheduler()

# Prior pipeline
prior_pipeline = WuerstchenPriorPipeline(
    prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)
prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior")

# Decoder pipeline
decoder_pipeline = WuerstchenDecoderPipeline(
    text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=decoder, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen")
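
# How the two stages chain at inference time (a sketch under assumed argument names,
# not part of this script): the prior turns a prompt into image embeddings, which the
# decoder then turns into pixels via the VQGAN, e.g.:
#   prior_output = prior_pipeline(prompt="a photo of a cat")
#   images = decoder_pipeline(image_embeddings=prior_output.image_embeddings, prompt="a photo of a cat").images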

# Combined Wuerstchen pipeline (prior + decoder in one object)
wuerstchen_pipeline = WuerstchenCombinedPipeline(
    # Decoder
    text_encoder=gen_text_encoder,
    tokenizer=gen_tokenizer,
    decoder=decoder,
    scheduler=scheduler,
    vqgan=vqmodel,
    # Prior
    prior_tokenizer=tokenizer,
    prior_text_encoder=text_encoder,
    prior=prior_model,
    prior_scheduler=scheduler,
)
wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline")
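
For reference, a minimal sketch of loading the converted checkpoint back with diffusers; the repository name matches the one saved above, while the prompt, dtype, and resolution are illustrative assumptions rather than part of this commit:

import torch
from diffusers import WuerstchenCombinedPipeline

pipe = WuerstchenCombinedPipeline.from_pretrained(
    "warp-diffusion/WuerstchenCombinedPipeline", torch_dtype=torch.float16
).to("cuda")
image = pipe(prompt="an astronaut riding a horse", height=1024, width=1024).images[0]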