
test the full pipeline

anton-l
2022-11-16 00:06:51 +01:00
parent 833cd1de1c
commit e455921ff0
7 changed files with 300 additions and 39 deletions


@@ -28,11 +28,11 @@ from diffusers import (
     EulerDiscreteScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
-    StableDiffusionPipeline,
     UNet2DConditionModel,
+    VersatileDiffusionPipeline,
 )
 
 from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
-from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
 
 SCHEDULER_CONFIG = Namespace(
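The import swap above is the crux of the commit: the tokenizer stays, but the encoders move from CLIPTextModel (per-token hidden states, as used by Stable Diffusion) to the *WithProjection variants, whose pooled outputs are mapped into CLIP's joint text/image embedding space. A minimal sketch of the difference, using the real transformers classes (the shapes assume ViT-L/14's 768-dim projection):

    from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    inputs = tokenizer(["a photo of a cat"], return_tensors="pt")

    # Per-token hidden states, the conditioning Stable Diffusion uses:
    encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    print(encoder(**inputs).last_hidden_state.shape)  # torch.Size([1, 7, 768])

    # The projection head additionally maps the pooled output into the joint
    # CLIP embedding space, which is what the projected variant exposes:
    encoder_proj = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    print(encoder_proj(**inputs).text_embeds.shape)  # torch.Size([1, 768])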
@@ -44,7 +44,7 @@ SCHEDULER_CONFIG = Namespace(
     }
 )
 
-UNET_IMAGE_CONFIG = Namespace(
+IMAGE_UNET_CONFIG = Namespace(
     **{
         "input_channels": 4,
         "model_channels": 320,
@@ -58,7 +58,7 @@ UNET_IMAGE_CONFIG = Namespace(
     }
 )
 
-UNET_TEXT_CONFIG = Namespace(
+TEXT_UNET_CONFIG = Namespace(
     **{
         "input_channels": 768,
         "model_channels": 320,
@@ -750,21 +750,20 @@ if __name__ == "__main__":
 
     # Convert the UNet2DConditionModel model.
     if args.unet_checkpoint_path is not None:
-        unet_image_config = create_unet_diffusers_config(UNET_IMAGE_CONFIG)
+        image_unet_config = create_unet_diffusers_config(IMAGE_UNET_CONFIG)
         checkpoint = torch.load(args.unet_checkpoint_path)
-        converted_unet_image_checkpoint = convert_vd_unet_checkpoint(
-            checkpoint, unet_image_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
+        converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
+            checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
         )
-        unet_image = UNet2DConditionModel(**unet_image_config)
-        unet_image.load_state_dict(converted_unet_image_checkpoint)
-        unet_image.save_pretrained(os.path.join(args.dump_path, "unet_image"))
+        image_unet = UNet2DConditionModel(**image_unet_config)
+        image_unet.load_state_dict(converted_image_unet_checkpoint)
 
-        # unet_text_config = create_unet_diffusers_config(UNET_TEXT_CONFIG)
-        # converted_unet_text_checkpoint = convert_vd_unet_checkpoint(
-        #     checkpoint, unet_text_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
+        # text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
+        # converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
+        #     checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
         # )
-        # unet_text = UNet2DConditionModel(**unet_text_config)
-        # unet_text.load_state_dict(converted_unet_text_checkpoint)
+        # text_unet = UNet2DConditionModel(**text_unet_config)
+        # text_unet.load_state_dict(converted_text_unet_checkpoint)
 
     # Convert the VAE model.
     if args.vae_checkpoint_path is not None:
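convert_vd_unet_checkpoint itself is outside this diff; from the call site it evidently selects the weights stored under unet_key and remaps them to the diffusers layout. A hypothetical sketch of just that first step, assuming the usual flat checkpoint["state_dict"] layout (the per-layer renaming and the extract_ema branch of the real helper are omitted):

    # Hypothetical helper, not the real convert_vd_unet_checkpoint: keep only
    # the entries under `unet_key` and strip the prefix so the result can be
    # loaded into a standalone UNet2DConditionModel.
    def extract_unet_state_dict(checkpoint, unet_key="model.diffusion_model.unet_image."):
        source = checkpoint.get("state_dict", checkpoint)
        return {
            key[len(unet_key):]: value
            for key, value in source.items()
            if key.startswith(unet_key)
        }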
@@ -774,28 +773,20 @@ if __name__ == "__main__":
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
-        vae.save_pretrained(os.path.join(args.dump_path, "vae"))
 
-    # Convert the text model.
-    # text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    # if text_model_type == "FrozenCLIPEmbedder":
-    #     text_model = convert_ldm_clip_checkpoint(checkpoint)
-    #     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-    #     safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
-    #     feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
-    #     pipe = StableDiffusionPipeline(
-    #         vae=vae,
-    #         text_encoder=text_model,
-    #         tokenizer=tokenizer,
-    #         unet=unet,
-    #         scheduler=scheduler,
-    #         safety_checker=safety_checker,
-    #         feature_extractor=feature_extractor,
-    #     )
-    # else:
-    #     text_config = create_ldm_bert_config(original_config)
-    #     text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
-    #     tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
-    #     pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-    #
-    # pipe.save_pretrained(args.dump_path)
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+    image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+    text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+
+    pipe = VersatileDiffusionPipeline(
+        scheduler=scheduler,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        text_encoder=text_encoder,
+        image_encoder=image_encoder,
+        image_unet=image_unet,
+        # text_unet=text_unet,
+        vae=vae,
+    )
+    pipe.save_pretrained(args.dump_path)
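With everything now saved through pipe.save_pretrained, a round-trip smoke test would look roughly like the sketch below. It assumes the script was run with --dump_path ./versatile-diffusion and that the dump is a complete pipeline repo (path and prompt are placeholders; text_to_image() is the text-to-image entry point the VersatileDiffusionPipeline exposes in diffusers):

    from diffusers import VersatileDiffusionPipeline

    # Load the converted weights back from the dump directory and run a
    # single text-to-image generation as a smoke test.
    pipe = VersatileDiffusionPipeline.from_pretrained("./versatile-diffusion")
    image = pipe.text_to_image("an astronaut riding a horse", num_inference_steps=25).images[0]
    image.save("astronaut.png")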