mirror of https://github.com/huggingface/diffusers.git
test the full pipeline
@@ -28,11 +28,11 @@ from diffusers import (
     EulerDiscreteScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionPipeline,
     UNet2DConditionModel,
     VersatileDiffusionPipeline,
 )
 from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
-from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection


 SCHEDULER_CONFIG = Namespace(
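A note on the import swap above: the *WithProjection CLIP classes expose the pooled embeddings passed through CLIP's projection heads, which is presumably what the new pipeline conditions on, given that they are handed to it below. A minimal sketch of the difference, using the same model id as the diff:

import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

inputs = tokenizer(["a photo of an astronaut"], return_tensors="pt")
outputs = text_encoder(**inputs)
# text_embeds is the pooled output projected into CLIP's joint embedding
# space; plain CLIPTextModel has no such field.
print(outputs.text_embeds.shape)  # torch.Size([1, 768])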
@@ -44,7 +44,7 @@ SCHEDULER_CONFIG = Namespace(
     }
 )

-UNET_IMAGE_CONFIG = Namespace(
+IMAGE_UNET_CONFIG = Namespace(
     **{
         "input_channels": 4,
         "model_channels": 320,
@@ -58,7 +58,7 @@ UNET_IMAGE_CONFIG = Namespace(
     }
 )

-UNET_TEXT_CONFIG = Namespace(
+TEXT_UNET_CONFIG = Namespace(
     **{
         "input_channels": 768,
         "model_channels": 320,
@@ -750,21 +750,20 @@ if __name__ == "__main__":

     # Convert the UNet2DConditionModel model.
     if args.unet_checkpoint_path is not None:
-        unet_image_config = create_unet_diffusers_config(UNET_IMAGE_CONFIG)
+        image_unet_config = create_unet_diffusers_config(IMAGE_UNET_CONFIG)
         checkpoint = torch.load(args.unet_checkpoint_path)
-        converted_unet_image_checkpoint = convert_vd_unet_checkpoint(
-            checkpoint, unet_image_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
+        converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
+            checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
         )
-        unet_image = UNet2DConditionModel(**unet_image_config)
-        unet_image.load_state_dict(converted_unet_image_checkpoint)
-        unet_image.save_pretrained(os.path.join(args.dump_path, "unet_image"))
+        image_unet = UNet2DConditionModel(**image_unet_config)
+        image_unet.load_state_dict(converted_image_unet_checkpoint)

-        # unet_text_config = create_unet_diffusers_config(UNET_TEXT_CONFIG)
-        # converted_unet_text_checkpoint = convert_vd_unet_checkpoint(
-        #     checkpoint, unet_text_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
+        # text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
+        # converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
+        #     checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
         # )
-        # unet_text = UNet2DConditionModel(**unet_text_config)
-        # unet_text.load_state_dict(converted_unet_text_checkpoint)
+        # text_unet = UNet2DConditionModel(**text_unet_config)
+        # text_unet.load_state_dict(converted_text_unet_checkpoint)

     # Convert the VAE model.
     if args.vae_checkpoint_path is not None:
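Aside: convert_vd_unet_checkpoint selects one stream of the combined Versatile Diffusion checkpoint by key prefix via unet_key. The prefix filtering presumably amounts to something like the sketch below; extract_sub_state_dict is an illustrative name, not a function in this script.

def extract_sub_state_dict(checkpoint, prefix):
    # Keep only the weights under `prefix` (e.g. "model.diffusion_model.unet_image."),
    # stripping the prefix so the keys can then be renamed to the diffusers layout.
    return {key[len(prefix):]: value for key, value in checkpoint.items() if key.startswith(prefix)}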
@@ -774,28 +773,20 @@ if __name__ == "__main__":

         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
-        vae.save_pretrained(os.path.join(args.dump_path, "vae"))

-    # Convert the text model.
-    # text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    # if text_model_type == "FrozenCLIPEmbedder":
-    #     text_model = convert_ldm_clip_checkpoint(checkpoint)
-    #     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-    #     safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
-    #     feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
-    #     pipe = StableDiffusionPipeline(
-    #         vae=vae,
-    #         text_encoder=text_model,
-    #         tokenizer=tokenizer,
-    #         unet=unet,
-    #         scheduler=scheduler,
-    #         safety_checker=safety_checker,
-    #         feature_extractor=feature_extractor,
-    #     )
-    # else:
-    #     text_config = create_ldm_bert_config(original_config)
-    #     text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
-    #     tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
-    #     pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-    #
-    # pipe.save_pretrained(args.dump_path)
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+    image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+    text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+
+    pipe = VersatileDiffusionPipeline(
+        scheduler=scheduler,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        text_encoder=text_encoder,
+        image_encoder=image_encoder,
+        image_unet=image_unet,
+        # text_unet=text_unet,
+        vae=vae,
+    )
+    pipe.save_pretrained(args.dump_path)
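Matching the commit message, a quick way to exercise the full converted pipeline is to reload the dump and sanity-check a component. A minimal sketch, where "./versatile-diffusion" stands in for whatever --dump_path was used in the conversion run:

import torch
from diffusers import VersatileDiffusionPipeline

# Reload everything pipe.save_pretrained() wrote to the dump path.
pipe = VersatileDiffusionPipeline.from_pretrained("./versatile-diffusion")

# The converted image_unet should round-trip with finite weights.
for name, param in pipe.image_unet.named_parameters():
    assert torch.isfinite(param).all(), f"non-finite weights in {name}"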