mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2024-01-19 10:07:50 +00:00
parent b65861800e
commit 36203576d5
5 changed files with 27 additions and 26 deletions

View File

@@ -54,11 +54,11 @@ if is_transformers_available():
 _import_structure = {}
 
 if is_torch_available():
-    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
-    _import_structure["utils"] = ["AttnProcsLayers"]
-    _import_structure["controlnet"] = ["FromOriginalControlnetMixin"]
     _import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
+    _import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
+    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
+    _import_structure["utils"] = ["AttnProcsLayers"]
     if is_transformers_available():
         _import_structure["single_file"] = ["FromSingleFileMixin"]
         _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
@@ -71,7 +71,7 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .autoencoder import FromOriginalVAEMixin
-        from .controlnet import FromOriginalControlnetMixin
+        from .controlnet import FromOriginalControlNetMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers
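
The rename above is case-only (`Controlnet` → `ControlNet`), so downstream code importing the old name fails with an `ImportError` once this lands. A minimal compatibility sketch, assuming only what the diff shows about `diffusers.loaders` (the try/except shim itself is hypothetical, not part of this commit):

```py
# Hypothetical downstream shim: prefer the new capitalization introduced by
# this commit, fall back to the old name on earlier diffusers versions.
try:
    from diffusers.loaders import FromOriginalControlNetMixin
except ImportError:  # diffusers releases before this rename
    from diffusers.loaders import FromOriginalControlnetMixin as FromOriginalControlNetMixin
```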

View File

@@ -22,14 +22,14 @@ from .single_file_utils import (
 class FromOriginalVAEMixin:
     """
-    Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
+    Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
     """
 
     @classmethod
     @validate_hf_hub_args
     def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         r"""
-        Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
+        Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
         `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
 
         Parameters:
@@ -62,32 +62,35 @@ class FromOriginalVAEMixin:
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
-            image_size (`int`, *optional*, defaults to 512):
-                The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
-                Diffusion v2 base model. Use 768 for Stable Diffusion v2.
             use_safetensors (`bool`, *optional*, defaults to `None`):
                 If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                 safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                 weights. If set to `False`, safetensors weights are not loaded.
+            image_size (`int`, *optional*, defaults to 512):
+                The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
+                Diffusion v2 base model. Use 768 for Stable Diffusion v2.
             upcast_attention (`bool`, *optional*, defaults to `None`):
                 Whether the attention computation should always be upcasted.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (for example the pipeline components of the
                 specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
                 method. See example below for more information.
 
+        <Tip warning={true}>
+
+        Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
+        a VAE from SDXL or a Stable Diffusion v2 model or higher.
+
+        </Tip>
+
         Examples:
 
         ```py
-        from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
-
-        url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"  # can also be a local path
-        model = ControlNetModel.from_single_file(url)
-
-        url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors"  # can also be a local path
-        pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
+        from diffusers import AutoencoderKL
+
+        url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"  # can also be local file
+        model = AutoencoderKL.from_single_file(url)
         ```
         """
         original_config_file = kwargs.pop("original_config_file", None)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
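
The new tip is the user-facing consequence of the `scaling_factor` plumbing added to `single_file_utils.py` below. A hedged sketch of what it asks for, assuming the public `stabilityai/sdxl-vae` checkpoint and the commonly used SDXL scaling factor of 0.13025 (neither value appears in this commit):

```py
# Sketch only: load an SDXL VAE from a single .safetensors file, passing both
# image_size and scaling_factor as the tip requires. The URL and the 0.13025
# value are assumptions, not taken from this commit.
from diffusers import AutoencoderKL

url = "https://huggingface.co/stabilityai/sdxl-vae/blob/main/sdxl_vae.safetensors"
vae = AutoencoderKL.from_single_file(url, image_size=1024, scaling_factor=0.13025)
```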

View File

@@ -20,7 +20,7 @@ from .single_file_utils import (
 )
 
 
-class FromOriginalControlnetMixin:
+class FromOriginalControlNetMixin:
     """
     Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
     """

View File

@@ -507,7 +507,7 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
     return controlnet_config
 
 
-def create_vae_diffusers_config(original_config, image_size: int):
+def create_vae_diffusers_config(original_config, image_size, scaling_factor=0.18215):
     """
     Creates a config for the diffusers based on the config of the LDM model.
     """
@@ -526,6 +526,7 @@ def create_vae_diffusers_config(original_config, image_size: int):
         "block_out_channels": tuple(block_out_channels),
         "latent_channels": vae_params["z_channels"],
         "layers_per_block": vae_params["num_res_blocks"],
+        "scaling_factor": scaling_factor,
     }
 
     return config
@@ -1134,17 +1135,14 @@ def create_diffusers_unet_model_from_ldm(
 def create_diffusers_vae_model_from_ldm(
-    pipeline_class_name,
-    original_config,
-    checkpoint,
-    image_size=None,
+    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=0.18215
 ):
     # import here to avoid circular imports
     from ..models import AutoencoderKL
 
     image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
 
-    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size, scaling_factor=scaling_factor)
 
     diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     ctx = init_empty_weights if is_accelerate_available() else nullcontext
View File

@@ -19,7 +19,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalControlnetMixin
+from ..loaders import FromOriginalControlNetMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -102,7 +102,7 @@ class ControlNetConditioningEmbedding(nn.Module):
         return embedding
 
 
-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
     """
     A ControlNet model.
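
Because `ControlNetModel` inherits the renamed mixin, single-file loading is unchanged at the call site. A quick usage sketch, reusing the checkpoint URL from the old docstring example above:

```py
# from_single_file is provided by the (renamed) FromOriginalControlNetMixin.
from diffusers import ControlNetModel

url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
controlnet = ControlNetModel.from_single_file(url)
```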