add from_ckpt method as Mixin (#2318)
* add mixin class for pipeline from original sd ckpt
* Improve
* make style
* merge main into
* Improve more
* fix more
* up
* Apply suggestions from code review
* finish docs
* rename
* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
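With this change, pipelines that include the new `FromCkptMixin` gain a `from_ckpt` classmethod that builds a pipeline directly from a single original-format `.ckpt`/`.safetensors` checkpoint. A minimal usage sketch (the checkpoint link is the one used in the docstring and tests below):

```py
import torch
from diffusers import StableDiffusionPipeline

# Load a pipeline straight from a single original-format checkpoint on the Hub.
pipe = StableDiffusionPipeline.from_ckpt(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
    torch_dtype=torch.float16,
)
pipe.to("cuda")
image = pipe("an astronaut riding a horse").images[0]
```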
@@ -36,3 +36,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
### LoraLoaderMixin

[[autodoc]] loaders.LoraLoaderMixin

### FromCkptMixin

[[autodoc]] loaders.FromCkptMixin
@@ -308,6 +308,7 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion

## FlaxStableDiffusionControlNetPipeline
[[autodoc]] FlaxStableDiffusionControlNetPipeline
@@ -30,4 +30,7 @@ Available Checkpoints are:
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion
- load_lora_weights
- save_lora_weights
@@ -30,7 +30,11 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion
- from_ckpt
- load_lora_weights
- save_lora_weights

[[autodoc]] FlaxStableDiffusionImg2ImgPipeline
- all
- __call__
@@ -31,7 +31,10 @@ Available checkpoints are:
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion
- load_lora_weights
- save_lora_weights

[[autodoc]] FlaxStableDiffusionInpaintPipeline
- all
- __call__
@@ -68,3 +68,6 @@ images[0].save("snowy_mountains.png")
[[autodoc]] StableDiffusionInstructPix2PixPipeline
- __call__
- all
- load_textual_inversion
- load_lora_weights
- save_lora_weights
@@ -39,6 +39,10 @@ Available Checkpoints are:
- disable_xformers_memory_efficient_attention
- enable_vae_tiling
- disable_vae_tiling
- load_textual_inversion
- from_ckpt
- load_lora_weights
- save_lora_weights

[[autodoc]] FlaxStableDiffusionPipeline
- all
@@ -109,7 +109,6 @@ try:
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .loaders import TextualInversionLoaderMixin
    from .pipelines import (
        AltDiffusionImg2ImgPipeline,
        AltDiffusionPipeline,
@@ -13,9 +13,11 @@
# limitations under the License.
import os
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

import torch
from huggingface_hub import hf_hub_download

from .models.attention_processor import LoRAAttnProcessor
from .utils import (
@@ -431,6 +433,7 @@ class TextualInversionLoaderMixin:
        Example:

        To load a textual inversion embedding vector in `diffusers` format:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch
@@ -463,6 +466,7 @@ class TextualInversionLoaderMixin:
        image = pipe(prompt, num_inference_steps=50).images[0]
        image.save("character.png")
        ```

        """
        if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
            raise ValueError(
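Since the docstring example above is split across two hunks, here is a complete sketch of the flow it documents, using the `sd-concepts-library/cat-toy` embedding that the diffusers docs use elsewhere:

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load a textual-inversion embedding in `diffusers` format; the learned
# placeholder token ("<cat-toy>") can then be used directly in prompts.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")

prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=50).images[0]
image.save("cat-backpack.png")
```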
@@ -1051,3 +1055,197 @@ class LoraLoaderMixin:
        save_function(state_dict, os.path.join(save_directory, weight_name))
        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")


class FromCkptMixin:
    """This helper class allows loading a `.ckpt` Stable Diffusion checkpoint file
    directly into the respective pipeline class."""
    @classmethod
    def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original `.ckpt`
        format.

        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).

        Parameters:
            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:
                    - A link to the .ckpt file on the Hub. Should be in the format
                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
                    - A path to a *file* containing all pipeline weights.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the
                dtype will be automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
                generated when running `huggingface-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
                a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be
                any identifier allowed by git.
            use_safetensors (`bool`, *optional*):
                If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
                default), the pipeline will load using `safetensors` if the safetensors weights are available *and*
                if `safetensors` is installed. If set to `False`, the pipeline will *not* use `safetensors`.
            extract_ema (`bool`, *optional*, defaults to `False`):
                Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA
                weights or not. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality
                images for inference. Non-EMA weights are usually better to continue fine-tuning.
            upcast_attention (`bool`, *optional*, defaults to `None`):
                Whether the attention computation should always be upcasted. This is necessary when running Stable
                Diffusion 2.1.
            image_size (`int`, *optional*, defaults to 512):
                The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion
                v2 Base. Use 768 for Stable Diffusion v2.
            prediction_type (`str`, *optional*):
                The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and
                Stable Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
            num_in_channels (`int`, *optional*, defaults to `None`):
                The number of input channels. If `None`, it will be automatically inferred.
            scheduler_type (`str`, *optional*, defaults to `'pndm'`):
                Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral",
                "dpm", "ddim"]`.
            load_safety_checker (`bool`, *optional*, defaults to `True`):
                Whether to load the safety checker or not.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load- and saveable variables (*i.e.*, the pipeline components) of the
                specific pipeline class. The overwritten components are then directly passed to the pipeline's
                `__init__` method. See example below for more information.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = StableDiffusionPipeline.from_ckpt(
        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
        ... )

        >>> # Load pipeline from a local file
        >>> # (the file was downloaded beforehand as ./v1-5-pruned-emaonly.ckpt)
        >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.ckpt")

        >>> # Enable float16 and move to GPU
        >>> pipeline = StableDiffusionPipeline.from_ckpt(
        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
        ...     torch_dtype=torch.float16,
        ... )
        >>> pipeline.to("cuda")
        ```
        """
        # import here to avoid circular dependency
        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        extract_ema = kwargs.pop("extract_ema", False)
        image_size = kwargs.pop("image_size", 512)
        scheduler_type = kwargs.pop("scheduler_type", "pndm")
        num_in_channels = kwargs.pop("num_in_channels", None)
        upcast_attention = kwargs.pop("upcast_attention", None)
        load_safety_checker = kwargs.pop("load_safety_checker", True)
        prediction_type = kwargs.pop("prediction_type", None)

        torch_dtype = kwargs.pop("torch_dtype", None)

        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

        pipeline_name = cls.__name__
        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
        from_safetensors = file_extension == "safetensors"

        if from_safetensors and use_safetensors is False:
            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")

        # TODO: For now we only support stable diffusion
        stable_unclip = None
        controlnet = False
        if pipeline_name == "StableDiffusionControlNetPipeline":
            model_type = "FrozenCLIPEmbedder"
            controlnet = True
        elif "StableDiffusion" in pipeline_name:
            model_type = "FrozenCLIPEmbedder"
        elif pipeline_name == "StableUnCLIPPipeline":
            model_type = "FrozenOpenCLIPEmbedder"
            stable_unclip = "txt2img"
        elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
            model_type = "FrozenOpenCLIPEmbedder"
            stable_unclip = "img2img"
        elif pipeline_name == "PaintByExamplePipeline":
            model_type = "PaintByExample"
        elif pipeline_name == "LDMTextToImagePipeline":
            model_type = "LDMTextToImage"
        else:
            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
        # remove huggingface url
        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
            if pretrained_model_link_or_path.startswith(prefix):
                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]

        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
        ckpt_path = Path(pretrained_model_link_or_path)
        if not ckpt_path.is_file():
            # get repo_id and (potentially nested) file path of ckpt in repo
            repo_id = str(Path().joinpath(*ckpt_path.parts[:2]))
            file_path = str(Path().joinpath(*ckpt_path.parts[2:]))

            if file_path.startswith("blob/"):
                file_path = file_path[len("blob/") :]

            if file_path.startswith("main/"):
                file_path = file_path[len("main/") :]
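            # Illustration (not part of the original diff): a Hub link such as
            # "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
            # has by this point been reduced to repo_id="runwayml/stable-diffusion-v1-5"
            # and file_path="v1-5-pruned-emaonly.ckpt", ready for hf_hub_download below.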
            pretrained_model_link_or_path = hf_hub_download(
                repo_id,
                filename=file_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                force_download=force_download,
            )
        pipe = download_from_original_stable_diffusion_ckpt(
            pretrained_model_link_or_path,
            pipeline_class=cls,
            model_type=model_type,
            stable_unclip=stable_unclip,
            controlnet=controlnet,
            from_safetensors=from_safetensors,
            extract_ema=extract_ema,
            image_size=image_size,
            scheduler_type=scheduler_type,
            num_in_channels=num_in_channels,
            upcast_attention=upcast_attention,
            load_safety_checker=load_safety_checker,
            prediction_type=prediction_type,
        )

        if torch_dtype is not None:
            pipe.to(torch_dtype=torch_dtype)

        return pipe
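The `image_size`/`prediction_type` knobs documented above matter for v2-style checkpoints. A hedged sketch of loading a 768-pixel v2.1 checkpoint (the exact file location is an assumption for illustration):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_ckpt(
    # Assumed Hub location of a Stable Diffusion 2.1 (768px) checkpoint.
    "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.ckpt",
    image_size=768,                  # v2 (non-base) models were trained at 768x768
    prediction_type="v_prediction",  # v2 (non-base) models use v-prediction
    upcast_attention=True,           # per the docstring, needed for Stable Diffusion 2.1
    torch_dtype=torch.float16,
)
pipe.to("cuda")
```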
@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -31,35 +31,30 @@ from transformers import (
    CLIPVisionModelWithProjection,
)

from diffusers import (
from ...models import (
    AutoencoderKL,
    ControlNetModel,
    PriorTransformer,
    UNet2DConditionModel,
)
from ...schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LDMTextToImagePipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    PriorTransformer,
    StableDiffusionControlNetPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionPipeline,
    StableUnCLIPImg2ImgPipeline,
    StableUnCLIPPipeline,
    UnCLIPScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

from ...utils import is_omegaconf_available, is_safetensors_available, logging
from ...utils.import_utils import BACKENDS_MAPPING
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline
from .safety_checker import StableDiffusionSafetyChecker
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -981,7 +976,6 @@ def download_from_original_stable_diffusion_ckpt(
    image_size: int = 512,
    prediction_type: str = None,
    model_type: str = None,
    is_img2img: bool = False,
    extract_ema: bool = False,
    scheduler_type: str = "pndm",
    num_in_channels: Optional[int] = None,
@@ -993,7 +987,8 @@ def download_from_original_stable_diffusion_ckpt(
    clip_stats_path: Optional[str] = None,
    controlnet: Optional[bool] = None,
    load_safety_checker: bool = True,
) -> StableDiffusionPipeline:
    pipeline_class: DiffusionPipeline = None,
) -> DiffusionPipeline:
    """
    Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
    config file.
@@ -1031,12 +1026,29 @@ def download_from_original_stable_diffusion_ckpt(
            Whether the attention computation should always be upcasted. This is necessary when running stable
            diffusion 2.1.
        device (`str`, *optional*, defaults to `None`):
            The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
            in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
            StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
            The device to use. Pass `None` to determine automatically.
        from_safetensors (`bool`, *optional*, defaults to `False`):
            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
        load_safety_checker (`bool`, *optional*, defaults to `True`):
            Whether to load the safety checker or not.
        pipeline_class (`DiffusionPipeline`, *optional*, defaults to `None`):
            The pipeline class to use. Pass `None` to determine automatically.
        return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
    """

    # import pipelines here to avoid circular import error when using from_ckpt method
    from diffusers import (
        LDMTextToImagePipeline,
        PaintByExamplePipeline,
        StableDiffusionControlNetPipeline,
        StableDiffusionPipeline,
        StableUnCLIPImg2ImgPipeline,
        StableUnCLIPPipeline,
    )

    if pipeline_class is None:
        pipeline_class = StableDiffusionPipeline

    if prediction_type == "v-prediction":
        prediction_type = "v_prediction"
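To make the new `pipeline_class` argument concrete, a small sketch (assuming the checkpoint was fetched beforehand as `./v1-5-pruned-emaonly.ckpt`):

```py
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Build an img2img pipeline instead of the default StableDiffusionPipeline
# by passing the target class explicitly.
pipe = download_from_original_stable_diffusion_ckpt(
    "./v1-5-pruned-emaonly.ckpt",
    pipeline_class=StableDiffusionImg2ImgPipeline,
)
```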
@@ -1198,44 +1210,16 @@ def download_from_original_stable_diffusion_ckpt(
                requires_safety_checker=False,
            )
        else:
            if (
                hasattr(original_config, "model")
                and hasattr(original_config.model, "target")
                and "LatentInpaintDiffusion" in original_config.model.target
            ):
                pipe = StableDiffusionInpaintPipeline(
                    vae=vae,
                    text_encoder=text_model,
                    tokenizer=tokenizer,
                    unet=unet,
                    scheduler=scheduler,
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
            else:
                if is_img2img:
                    pipe = StableDiffusionImg2ImgPipeline(
                        vae=vae,
                        text_encoder=text_model,
                        tokenizer=tokenizer,
                        unet=unet,
                        scheduler=scheduler,
                        safety_checker=None,
                        feature_extractor=None,
                        requires_safety_checker=False,
                    )
                else:
                    pipe = StableDiffusionPipeline(
                        vae=vae,
                        text_encoder=text_model,
                        tokenizer=tokenizer,
                        unet=unet,
                        scheduler=scheduler,
                        safety_checker=None,
                        feature_extractor=None,
                        requires_safety_checker=False,
                    )
            pipe = pipeline_class(
                vae=vae,
                text_encoder=text_model,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
            )
        else:
            image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
                original_config, clip_stats_path=clip_stats_path, device=device
@@ -1326,41 +1310,15 @@ def download_from_original_stable_diffusion_ckpt(
                feature_extractor=feature_extractor,
            )
        else:
            if (
                hasattr(original_config, "model")
                and hasattr(original_config.model, "target")
                and "LatentInpaintDiffusion" in original_config.model.target
            ):
                pipe = StableDiffusionInpaintPipeline(
                    vae=vae,
                    text_encoder=text_model,
                    tokenizer=tokenizer,
                    unet=unet,
                    scheduler=scheduler,
                    safety_checker=safety_checker,
                    feature_extractor=feature_extractor,
                )
            else:
                if is_img2img:
                    pipe = StableDiffusionImg2ImgPipeline(
                        vae=vae,
                        text_encoder=text_model,
                        tokenizer=tokenizer,
                        unet=unet,
                        scheduler=scheduler,
                        safety_checker=safety_checker,
                        feature_extractor=feature_extractor,
                    )
                else:
                    pipe = StableDiffusionPipeline(
                        vae=vae,
                        text_encoder=text_model,
                        tokenizer=tokenizer,
                        unet=unet,
                        scheduler=scheduler,
                        safety_checker=safety_checker,
                        feature_extractor=feature_extractor,
                    )
            pipe = pipeline_class(
                vae=vae,
                text_encoder=text_model,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
            )
    else:
        text_config = create_ldm_bert_config(original_config)
        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
@@ -1379,7 +1337,7 @@ def download_controlnet_from_original_ckpt(
    upcast_attention: Optional[bool] = None,
    device: str = None,
    from_safetensors: bool = False,
) -> StableDiffusionPipeline:
) -> DiffusionPipeline:
    if not is_omegaconf_available():
        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
@@ -20,7 +20,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -53,13 +53,21 @@ EXAMPLE_DOC_STRING = """
"""


class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
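As a quick illustration of the loading methods listed in the docstring, a hedged sketch (the LoRA repo id is a hypothetical placeholder):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Hypothetical Hub repo containing LoRA attention weights, e.g. one produced
# by the diffusers LoRA training script.
pipe.load_lora_weights("some-user/sd-1-5-pokemon-lora")
image = pipe("a cute green pokemon", num_inference_steps=25).images[0]
```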
@@ -156,6 +156,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -23,7 +23,7 @@ from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation

from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -55,13 +55,20 @@ def preprocess(image):
    return image


class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for text-guided image to image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -92,13 +92,21 @@ def preprocess(image):
    return image


class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
    r"""
    Pipeline for text-guided image to image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -22,7 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -138,13 +138,20 @@ def prepare_mask_and_masked_image(image, mask):
    return mask, masked_image


class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -22,7 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -82,13 +82,23 @@ def preprocess_mask(mask, scale_factor=8):
    return mask


class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionInpaintPipelineLegacy(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -20,7 +20,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -61,13 +61,20 @@ def preprocess(image):
    return image


class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -2,21 +2,6 @@
from ..utils import DummyObject, requires_backends


class TextualInversionLoaderMixin(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
@@ -36,6 +36,7 @@ from diffusers import (
    UNet2DConditionModel,
    logging,
)
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
@@ -865,6 +866,62 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
        assert max_diff < 5e-2


@slow
@require_torch_gpu
class StableDiffusionPipelineCkptTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_download_from_hub(self):
        ckpt_paths = [
            "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
            "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix_base.ckpt",
        ]

        for ckpt_path in ckpt_paths:
            pipe = StableDiffusionPipeline.from_ckpt(ckpt_path, torch_dtype=torch.float16)
            pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
            pipe.to("cuda")

            image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]

            assert image_out.shape == (512, 512, 3)

    def test_download_local(self):
        filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt")

        pipe = StableDiffusionPipeline.from_ckpt(filename, torch_dtype=torch.float16)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]

        assert image_out.shape == (512, 512, 3)

    def test_download_ckpt_diff_format_is_same(self):
        ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"

        pipe = StableDiffusionPipeline.from_ckpt(ckpt_path)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.unet.set_attn_processor(AttnProcessor())
        pipe.to("cuda")

        generator = torch.Generator(device="cpu").manual_seed(0)
        image_ckpt = pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]

        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.unet.set_attn_processor(AttnProcessor())
        pipe.to("cuda")

        generator = torch.Generator(device="cpu").manual_seed(0)
        image = pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]

        assert np.max(np.abs(image - image_ckpt)) < 1e-4


@nightly
@require_torch_gpu
class StableDiffusionPipelineNightlyTests(unittest.TestCase):