1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Adds local_files_only bool to prevent forced online connection (#3486)

This commit is contained in:
w4ffl35
2023-05-22 08:44:36 -06:00
committed by GitHub
parent 194b0a425d
commit 0160e5146f

View File

@@ -727,8 +727,8 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
return hf_model
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
keys = list(checkpoint.keys())
@@ -992,6 +992,7 @@ def download_from_original_stable_diffusion_ckpt(
controlnet: Optional[bool] = None,
load_safety_checker: bool = True,
pipeline_class: DiffusionPipeline = None,
local_files_only=False
) -> DiffusionPipeline:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1037,6 +1038,8 @@ def download_from_original_stable_diffusion_ckpt(
Whether to load the safety checker or not. Defaults to `True`.
pipeline_class (`str`, *optional*, defaults to `None`):
The pipeline class to use. Pass `None` to determine automatically.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""
@@ -1292,7 +1295,7 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor=feature_extractor,
)
elif model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
text_model = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
if load_safety_checker: