1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Würstchen model (#3849)

* initial

* initial

* added initial convert script for paella vqmodel

* initial wuerstchen pipeline

* add LayerNorm2d

* added modules

* fix typo

* use model_v2

* embed clip caption amd negative_caption

* fixed name of var

* initial modules in one place

* WuerstchenPriorPipeline

* inital shape

* initial denoising prior loop

* fix output

* add WuerstchenPriorPipeline to __init__.py

* use the noise ratio in the Prior

* try to save pipeline

* save_pretrained working

* Few additions

* add _execution_device

* shape is int

* fix batch size

* fix shape of ratio

* fix shape of ratio

* fix output dataclass

* tests folder

* fix formatting

* fix float16 + started with generator

* Update pipeline_wuerstchen.py

* removed vqgan code

* add WuerstchenGeneratorPipeline

* fix WuerstchenGeneratorPipeline

* fix docstrings

* fix imports

* convert generator pipeline

* fix convert

* Work on Generator Pipeline. WIP

* Pipeline works with our diffuzz code

* apply scale factor

* removed vqgan.py

* use cosine schedule

* redo the denoising loop

* Update src/diffusers/models/resnet.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* use torch.lerp

* use warp-diffusion org

* clip_sample=False,

* some refactoring

* use model_v3_stage_c

* c_cond size

* use clip-bigG

* allow stage b clip to be None

* add dummy

* würstchen scheduler

* minor changes

* set clip=None in the pipeline

* fix attention mask

* add attention_masks to text_encoder

* make fix-copies

* add back clip

* add text_encoder

* gen_text_encoder and tokenizer

* fix import

* updated pipeline test

* undo changes to pipeline test

* nip

* fix typo

* fix output name

* set guidance_scale=0 and remove diffuze

* fix doc strings

* make style

* nip

* removed unused

* initial docs

* rename

* toc

* cleanup

* remvoe test script

* fix-copies

* fix multi images

* remove dup

* remove unused modules

* undo changes for debugging

* no  new line

* remove dup conversion script

* fix doc string

* cleanup

* pass default args

* dup permute

* fix some tests

* fix prepare_latents

* move Prior class to modules

* offload only the text encoder and vqgan

* fix resolution calculation for prior

* nip

* removed testing script

* fix shape

* fix argument to set_timesteps

* do not change .gitignore

* fix resolution calculations + readme

* resolution calculation fix + readme

* small fixes

* Add combined pipeline

* rename generator -> decoder

* Update .gitignore

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* removed efficient_net

* create combined WuerstchenPipeline

* make arguments consistent with VQ model

* fix var names

* no need to return text_encoder_hidden_states

* add latent_dim_scale to config

* split model into its own file

* add WuerschenPipeline to docs

* remove unused latent_size

* register latent_dim_scale

* update script

* update docstring

* use Attention preprocessor

* concat with normed input

* fix-copies

* add docs

* fix test

* fix style

* add to cpu_offloaded_model

* updated type

* remove 1-line func

* updated type

* initial decoder test

* formatting

* formatting

* fix autodoc link

* num_inference_steps is int

* remove comments

* fix example in docs

* Update src/diffusers/pipelines/wuerstchen/diffnext.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* rename layernorm to WuerstchenLayerNorm

* rename DiffNext to WuerstchenDiffNeXt

* added comment about MixingResidualBlock

* move paella vq-vae to pipelines' folder

* initial decoder test

* increased test_float16_inference expected diff

* self_attn is always true

* more passing decoder tests

* batch image_embeds

* fix failing tests

* set the correct dtype

* relax inference test

* update prior

* added combined pipeline test

* faster test

* faster test

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix issues from review

* update wuerstchen.md + change generator name

* resolve issues

* fix copied from usage and add back batch_size

* fix API

* fix arguments

* fix combined test

* Added timesteps argument + fixes

* Update tests/pipelines/test_pipelines_common.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/pipelines/wuerstchen/test_wuerstchen_prior.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

* up

* Fix more

* failing tests

* up

* up

* correct naming

* correct docs

* correct docs

* fix test params

* correct docs

* fix classifier free guidance

* fix classifier free guidance

* fix more

* fix all

* make tests faster

---------

Co-authored-by: Dominic Rampas <d6582533@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dominic Rampas <61938694+dome272@users.noreply.github.com>
This commit is contained in:
Kashif Rasul
2023-09-06 16:15:51 +02:00
committed by GitHub
parent b76274cb53
commit 541bb6ee63
26 changed files with 2838 additions and 15 deletions

View File

@@ -0,0 +1,115 @@
# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/
import os
import torch
from transformers import AutoTokenizer, CLIPTextModel
from vqgan import VQModel
from diffusers import (
DDPMWuerstchenScheduler,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
model_path = "models/"
device = "cpu"
paella_vqmodel = VQModel()
state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
paella_vqmodel.load_state_dict(state_dict)
state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
state_dict.pop("vquantizer.codebook.weight")
vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
vqmodel.load_state_dict(state_dict)
# Clip Text encoder and tokenizer
text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
# Generator
gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu")
gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"]
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
deocder = WuerstchenDiffNeXt()
deocder.load_state_dict(state_dict)
# Prior
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device)
prior_model.load_state_dict(state_dict)
# scheduler
scheduler = DDPMWuerstchenScheduler()
# Prior pipeline
prior_pipeline = WuerstchenPriorPipeline(
prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)
prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior")
decoder_pipeline = WuerstchenDecoderPipeline(
text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen")
# Wuerstchen pipeline
wuerstchen_pipeline = WuerstchenCombinedPipeline(
# Decoder
text_encoder=gen_text_encoder,
tokenizer=gen_tokenizer,
decoder=deocder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_tokenizer=tokenizer,
prior_text_encoder=text_encoder,
prior=prior_model,
prior_scheduler=scheduler,
)
wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline")