mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model (#2705)
* [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model * Address review comment from PR * PyLint formatting * Some more pylint fixes, unrelated to our change * Another pylint fix * Styling fix
This commit is contained in:
@@ -45,6 +45,8 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
PriorTransformer,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
@@ -979,6 +981,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
image_size: int = 512,
|
||||
prediction_type: str = None,
|
||||
model_type: str = None,
|
||||
is_img2img: bool = False,
|
||||
extract_ema: bool = False,
|
||||
scheduler_type: str = "pndm",
|
||||
num_in_channels: Optional[int] = None,
|
||||
@@ -1018,6 +1021,8 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
model_type (`str`, *optional*, defaults to `None`):
|
||||
The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
|
||||
"FrozenCLIPEmbedder", "PaintByExample"]`.
|
||||
is_img2img (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model should be loaded as an img2img pipeline.
|
||||
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
|
||||
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
|
||||
`False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
|
||||
@@ -1193,16 +1198,44 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
if (
|
||||
hasattr(original_config, "model")
|
||||
and hasattr(original_config.model, "target")
|
||||
and "LatentInpaintDiffusion" in original_config.model.target
|
||||
):
|
||||
pipe = StableDiffusionInpaintPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
else:
|
||||
if is_img2img:
|
||||
pipe = StableDiffusionImg2ImgPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
else:
|
||||
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
|
||||
original_config, clip_stats_path=clip_stats_path, device=device
|
||||
@@ -1293,15 +1326,41 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
if (
|
||||
hasattr(original_config, "model")
|
||||
and hasattr(original_config.model, "target")
|
||||
and "LatentInpaintDiffusion" in original_config.model.target
|
||||
):
|
||||
pipe = StableDiffusionInpaintPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
else:
|
||||
if is_img2img:
|
||||
pipe = StableDiffusionImg2ImgPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
else:
|
||||
text_config = create_ldm_bert_config(original_config)
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
|
||||
Reference in New Issue
Block a user