1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2023-12-28 12:36:05 +00:00
parent 0012dd2309
commit 2616e03062
2 changed files with 34 additions and 0 deletions

View File

@@ -41,6 +41,7 @@ from .single_file_utils import (
create_unet_model,
create_vae_model,
fetch_original_config,
infer_model_type,
)
@@ -218,6 +219,14 @@ def build_additional_components(
)
return paint_by_example_components
if pipeline_class_name == "StableDiffusionXLImg2ImgPipeline":
model_type = infer_model_type(pipeline_class_name, original_config)
is_refiner = model_type == "SDXL-Refiner"
return {
"requires_aesthetics_score": is_refiner,
"force_zeros_for_empty_prompt": False if is_refiner else True,
}
class FromSingleFileMixin:
"""

View File

@@ -1287,6 +1287,21 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi
return unet
def create_controlnet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, image_size, **kwargs):
if "control_stage_config" not in original_config.model.params:
raise ValueError("Config does not have controlnet information")
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
extract_ema = kwargs.get("extract_ema", False)
upcast_attention = kwargs.get("upcast_attention", False)
controlnet = convert_controlnet_checkpoint(
checkpoint, original_config, path, image_size, upcast_attention, extract_ema
)
return {"controlnet": controlnet}
def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
vae_config = create_vae_diffusers_config(original_config)
diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
@@ -1477,6 +1492,16 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
"""
elif model_type == "UpScale":
elif pipeline_class == StableDiffusionUpscalePipeline:
scheduler = DDIMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler"
)
low_res_scheduler = DDPMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
)
"""
return {"scheduler": scheduler}