1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2023-12-29 14:24:27 +00:00
parent 7a8c72200a
commit ccf8d62c22
2 changed files with 34 additions and 20 deletions

View File

@@ -339,7 +339,13 @@ class FromSingleFileMixin:
pipeline_components = {}
for component in component_names:
components = build_component(
pipeline_components, pipeline_name, component, checkpoint, original_config, **kwargs
pipeline_components,
pipeline_name,
component,
original_config,
checkpoint,
pretrained_model_link_or_path,
**kwargs,
)
pipeline_components.update(components)
@@ -354,4 +360,3 @@ class FromSingleFileMixin:
pipe.to(dtype=torch_dtype)
return pipe

View File

@@ -222,6 +222,23 @@ def get_default_scheduler_config():
return SCHEDULER_DEFAULT_CONFIG
def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs):
image_size = kwargs.get("image_size", 512)
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = image_size or original_config.model.params.unet_config.params.image_size
elif (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
):
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
# as it relies on a brittle global step parameter here
image_size = 512 if global_step == 875000 else 768
return image_size
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
@@ -1268,19 +1285,7 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi
else:
num_in_channels = 4
image_size = kwargs.get("image_size", 512)
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = image_size or original_config.model.params.unet_config.params.image_size
elif (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
):
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
# as it relies on a brittle global step parameter here
image_size = 512 if global_step == 875000 else 768
image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
upcast_attention = kwargs.get("upcast_attention", False)
extract_ema = kwargs.get("extract_ema", False)
@@ -1303,7 +1308,7 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi
else:
unet.load_state_dict(diffusers_format_unet_checkpoint)
return unet
return {"unet": unet}
def create_controlnet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, image_size, **kwargs):
@@ -1321,10 +1326,13 @@ def create_controlnet_model(pipeline_class_name, original_config, checkpoint, ch
return {"controlnet": controlnet}
def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
vae_config = create_vae_diffusers_config(original_config)
def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
vae = AutoencoderKL(**vae_config)
@@ -1334,7 +1342,7 @@ def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, **kwa
else:
vae.load_state_dict(diffusers_format_vae_checkpoint)
return vae
return {"vae": vae}
def create_text_encoders_and_tokenizers(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
@@ -1425,7 +1433,8 @@ def create_text_encoders_and_tokenizers(pipeline_class_name, original_config, ch
local_files_only=local_files_only,
**config_kwargs,
)
except Exception:
except Exception as e:
print(e)
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
)