
[ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model (#2705)

* [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model

* Address review comment from PR

* PyLint formatting

* Some more pylint fixes, unrelated to our change

* Another pylint fix

* Styling fix
Author: cmdr2
Date: 2023-04-19 18:07:07 +05:30
Committed by: GitHub
Parent commit: fc1883918f
Commit: bdeff4d64a


@@ -45,6 +45,8 @@ from diffusers import (
     PNDMScheduler,
     PriorTransformer,
     StableDiffusionControlNetPipeline,
+    StableDiffusionImg2ImgPipeline,
+    StableDiffusionInpaintPipeline,
     StableDiffusionPipeline,
     StableUnCLIPImg2ImgPipeline,
     StableUnCLIPPipeline,
@@ -979,6 +981,7 @@ def download_from_original_stable_diffusion_ckpt(
     image_size: int = 512,
     prediction_type: str = None,
     model_type: str = None,
+    is_img2img: bool = False,
     extract_ema: bool = False,
     scheduler_type: str = "pndm",
     num_in_channels: Optional[int] = None,
@@ -1018,6 +1021,8 @@ def download_from_original_stable_diffusion_ckpt(
         model_type (`str`, *optional*, defaults to `None`):
             The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
             "FrozenCLIPEmbedder", "PaintByExample"]`.
+        is_img2img (`bool`, *optional*, defaults to `False`):
+            Whether the model should be loaded as an img2img pipeline.
         extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
             checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
             `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
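
For orientation, here is a minimal usage sketch of the new flag. It is not part of the diff: the import path, the `checkpoint_path`/`original_config_file` argument names, and the file names are assumptions made for illustration.

    # Hypothetical sketch: convert a .ckpt straight to an img2img pipeline.
    # Paths, argument names, and the import location are assumptions.
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path="sd-v1-5.ckpt",            # assumed local checkpoint file
        original_config_file="v1-inference.yaml",  # assumed original LDM config
        is_img2img=True,                           # new flag added by this commit
    )
    # With is_img2img=True the converter should now return a
    # StableDiffusionImg2ImgPipeline rather than a StableDiffusionPipeline.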
@@ -1193,16 +1198,44 @@ def download_from_original_stable_diffusion_ckpt(
                     requires_safety_checker=False,
                 )
             else:
-                pipe = StableDiffusionPipeline(
-                    vae=vae,
-                    text_encoder=text_model,
-                    tokenizer=tokenizer,
-                    unet=unet,
-                    scheduler=scheduler,
-                    safety_checker=None,
-                    feature_extractor=None,
-                    requires_safety_checker=False,
-                )
+                if (
+                    hasattr(original_config, "model")
+                    and hasattr(original_config.model, "target")
+                    and "LatentInpaintDiffusion" in original_config.model.target
+                ):
+                    pipe = StableDiffusionInpaintPipeline(
+                        vae=vae,
+                        text_encoder=text_model,
+                        tokenizer=tokenizer,
+                        unet=unet,
+                        scheduler=scheduler,
+                        safety_checker=None,
+                        feature_extractor=None,
+                        requires_safety_checker=False,
+                    )
+                else:
+                    if is_img2img:
+                        pipe = StableDiffusionImg2ImgPipeline(
+                            vae=vae,
+                            text_encoder=text_model,
+                            tokenizer=tokenizer,
+                            unet=unet,
+                            scheduler=scheduler,
+                            safety_checker=None,
+                            feature_extractor=None,
+                            requires_safety_checker=False,
+                        )
+                    else:
+                        pipe = StableDiffusionPipeline(
+                            vae=vae,
+                            text_encoder=text_model,
+                            tokenizer=tokenizer,
+                            unet=unet,
+                            scheduler=scheduler,
+                            safety_checker=None,
+                            feature_extractor=None,
+                            requires_safety_checker=False,
+                        )
         else:
             image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
                 original_config, clip_stats_path=clip_stats_path, device=device
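
The added branch keys off the `model.target` entry of the original LDM YAML config (an OmegaConf object in this converter). Below is a small, self-contained sketch of that selection rule; the `pick_pipeline_class` helper and the example target string are illustrative, not part of the diff.

    # Illustrative sketch of the pipeline-selection logic introduced above.
    from omegaconf import OmegaConf

    from diffusers import (
        StableDiffusionImg2ImgPipeline,
        StableDiffusionInpaintPipeline,
        StableDiffusionPipeline,
    )

    def pick_pipeline_class(original_config, is_img2img=False):
        # Inpainting checkpoints declare a LatentInpaintDiffusion target in their YAML.
        if (
            hasattr(original_config, "model")
            and hasattr(original_config.model, "target")
            and "LatentInpaintDiffusion" in original_config.model.target
        ):
            return StableDiffusionInpaintPipeline
        # Otherwise the caller chooses between img2img and plain text-to-image.
        return StableDiffusionImg2ImgPipeline if is_img2img else StableDiffusionPipeline

    # The target string below is typical of original inpainting configs (assumption).
    config = OmegaConf.create(
        {"model": {"target": "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"}}
    )
    print(pick_pipeline_class(config).__name__)  # StableDiffusionInpaintPipeline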
@@ -1293,15 +1326,41 @@ def download_from_original_stable_diffusion_ckpt(
                 feature_extractor=feature_extractor,
             )
         else:
-            pipe = StableDiffusionPipeline(
-                vae=vae,
-                text_encoder=text_model,
-                tokenizer=tokenizer,
-                unet=unet,
-                scheduler=scheduler,
-                safety_checker=safety_checker,
-                feature_extractor=feature_extractor,
-            )
+            if (
+                hasattr(original_config, "model")
+                and hasattr(original_config.model, "target")
+                and "LatentInpaintDiffusion" in original_config.model.target
+            ):
+                pipe = StableDiffusionInpaintPipeline(
+                    vae=vae,
+                    text_encoder=text_model,
+                    tokenizer=tokenizer,
+                    unet=unet,
+                    scheduler=scheduler,
+                    safety_checker=safety_checker,
+                    feature_extractor=feature_extractor,
+                )
+            else:
+                if is_img2img:
+                    pipe = StableDiffusionImg2ImgPipeline(
+                        vae=vae,
+                        text_encoder=text_model,
+                        tokenizer=tokenizer,
+                        unet=unet,
+                        scheduler=scheduler,
+                        safety_checker=safety_checker,
+                        feature_extractor=feature_extractor,
+                    )
+                else:
+                    pipe = StableDiffusionPipeline(
+                        vae=vae,
+                        text_encoder=text_model,
+                        tokenizer=tokenizer,
+                        unet=unet,
+                        scheduler=scheduler,
+                        safety_checker=safety_checker,
+                        feature_extractor=feature_extractor,
+                    )
     else:
         text_config = create_ldm_bert_config(original_config)
         text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
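
Taken together, the two hunks mean that an original inpainting checkpoint should now come back as a StableDiffusionInpaintPipeline automatically, while img2img remains opt-in through is_img2img (as sketched earlier). A hedged sketch of the inpainting case; the file names, argument names, and import path are assumptions, and the printed class name is the expected outcome rather than captured output.

    # Hypothetical sketch: inpainting is inferred from the original YAML config,
    # no extra flag required. Paths and argument names are assumptions.
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path="sd-v1-5-inpainting.ckpt",             # assumed inpainting .ckpt
        original_config_file="v1-inpainting-inference.yaml",   # assumed original config
    )
    print(type(pipe).__name__)  # expected: StableDiffusionInpaintPipeline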