1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add from_ckpt method as Mixin (#2318)

* add mixin class for pipeline from original sd ckpt

* Improve

* make style

* merge main into

* Improve more

* fix more

* up

* Apply suggestions from code review

* finish docs

* rename

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
1lint
2023-04-19 11:07:36 -05:00
committed by Daniel Gu
parent 1fac211336
commit f3300a869a
21 changed files with 410 additions and 125 deletions

View File

@@ -36,3 +36,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
### LoraLoaderMixin
[[autodoc]] loaders.LoraLoaderMixin
### FromCkptMixin
[[autodoc]] loaders.FromCkptMixin

View File

@@ -308,6 +308,7 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion
## FlaxStableDiffusionControlNetPipeline
[[autodoc]] FlaxStableDiffusionControlNetPipeline

View File

@@ -30,4 +30,7 @@ Available Checkpoints are:
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion
- load_lora_weights
- save_lora_weights

View File

@@ -30,7 +30,11 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion
- from_ckpt
- load_lora_weights
- save_lora_weights
[[autodoc]] FlaxStableDiffusionImg2ImgPipeline
- all
- __call__
- __call__

View File

@@ -31,7 +31,10 @@ Available checkpoints are:
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion
- load_lora_weights
- save_lora_weights
[[autodoc]] FlaxStableDiffusionInpaintPipeline
- all
- __call__
- __call__

View File

@@ -68,3 +68,6 @@ images[0].save("snowy_mountains.png")
[[autodoc]] StableDiffusionInstructPix2PixPipeline
- __call__
- all
- load_textual_inversion
- load_lora_weights
- save_lora_weights

View File

@@ -39,6 +39,10 @@ Available Checkpoints are:
- disable_xformers_memory_efficient_attention
- enable_vae_tiling
- disable_vae_tiling
- load_textual_inversion
- from_ckpt
- load_lora_weights
- save_lora_weights
[[autodoc]] FlaxStableDiffusionPipeline
- all

View File

@@ -109,7 +109,6 @@ try:
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .loaders import TextualInversionLoaderMixin
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,

View File

@@ -13,9 +13,11 @@
# limitations under the License.
import os
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import torch
from huggingface_hub import hf_hub_download
from .models.attention_processor import LoRAAttnProcessor
from .utils import (
@@ -431,6 +433,7 @@ class TextualInversionLoaderMixin:
Example:
To load a textual inversion embedding vector in `diffusers` format:
```py
from diffusers import StableDiffusionPipeline
import torch
@@ -463,6 +466,7 @@ class TextualInversionLoaderMixin:
image = pipe(prompt, num_inference_steps=50).images[0]
image.save("character.png")
```
"""
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
raise ValueError(
@@ -1051,3 +1055,197 @@ class LoraLoaderMixin:
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
class FromCkptMixin:
"""This helper class allows to directly load .ckpt stable diffusion file_extension
into the respective classes."""
@classmethod
def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original .ckpt format.
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the .ckpt file on the Hub. Should be in the format
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
use_safetensors (`bool`, *optional* ):
If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
default). The pipeline will load using `safetensors` if the safetensors weights are available *and* if
`safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults
to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
inference. Non-EMA weights are usually better to continue fine-tuning.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted. This is necessary when running stable
image_size (`int`, *optional*, defaults to 512):
The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
Base. Use 768 for Stable Diffusion v2.
prediction_type (`str`, *optional*):
The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
num_in_channels (`int`, *optional*, defaults to None):
The number of input channels. If `None`, it will be automatically inferred.
scheduler_type (`str`, *optional*, defaults to 'pndm'):
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
"ddim"]`.
load_safety_checker (`bool`, *optional*, defaults to `True`):
Whether to load the safety checker or not. Defaults to `True`.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines
`__init__` method. See example below for more information.
Examples:
```py
>>> from diffusers import StableDiffusionPipeline
>>> # Download pipeline from huggingface.co and cache.
>>> pipeline = StableDiffusionPipeline.from_ckpt(
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
... )
>>> # Download pipeline from local file
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
>>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly")
>>> # Enable float16 and move to GPU
>>> pipeline = StableDiffusionPipeline.from_ckpt(
... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
... torch_dtype=torch.float16,
... )
>>> pipeline.to("cuda")
```
"""
# import here to avoid circular dependency
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
extract_ema = kwargs.pop("extract_ema", False)
image_size = kwargs.pop("image_size", 512)
scheduler_type = kwargs.pop("scheduler_type", "pndm")
num_in_channels = kwargs.pop("num_in_channels", None)
upcast_attention = kwargs.pop("upcast_attention", None)
load_safety_checker = kwargs.pop("load_safety_checker", True)
prediction_type = kwargs.pop("prediction_type", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
pipeline_name = cls.__name__
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
from_safetensors = file_extension == "safetensors"
if from_safetensors and use_safetensors is True:
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
# TODO: For now we only support stable diffusion
stable_unclip = None
controlnet = False
if pipeline_name == "StableDiffusionControlNetPipeline":
model_type = "FrozenCLIPEmbedder"
controlnet = True
elif "StableDiffusion" in pipeline_name:
model_type = "FrozenCLIPEmbedder"
elif pipeline_name == "StableUnCLIPPipeline":
model_type == "FrozenOpenCLIPEmbedder"
stable_unclip = "txt2img"
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
model_type == "FrozenOpenCLIPEmbedder"
stable_unclip = "img2img"
elif pipeline_name == "PaintByExamplePipeline":
model_type == "PaintByExample"
elif pipeline_name == "LDMTextToImagePipeline":
model_type == "LDMTextToImage"
else:
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
# remove huggingface url
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(pretrained_model_link_or_path)
if not ckpt_path.is_file():
# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = str(Path().joinpath(*ckpt_path.parts[:2]))
file_path = str(Path().joinpath(*ckpt_path.parts[2:]))
if file_path.startswith("blob/"):
file_path = file_path[len("blob/") :]
if file_path.startswith("main/"):
file_path = file_path[len("main/") :]
pretrained_model_link_or_path = hf_hub_download(
repo_id,
filename=file_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
pipe = download_from_original_stable_diffusion_ckpt(
pretrained_model_link_or_path,
pipeline_class=cls,
model_type=model_type,
stable_unclip=stable_unclip,
controlnet=controlnet,
from_safetensors=from_safetensors,
extract_ema=extract_ema,
image_size=image_size,
scheduler_type=scheduler_type,
num_in_channels=num_in_channels,
upcast_attention=upcast_attention,
load_safety_checker=load_safety_checker,
prediction_type=prediction_type,
)
if torch_dtype is not None:
pipe.to(torch_dtype=torch_dtype)
return pipe

View File

@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

View File

@@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

View File

@@ -31,35 +31,30 @@ from transformers import (
CLIPVisionModelWithProjection,
)
from diffusers import (
from ...models import (
AutoencoderKL,
ControlNetModel,
PriorTransformer,
UNet2DConditionModel,
)
from ...schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMScheduler,
PriorTransformer,
StableDiffusionControlNetPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
UnCLIPScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from ...utils import is_omegaconf_available, is_safetensors_available, logging
from ...utils.import_utils import BACKENDS_MAPPING
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline
from .safety_checker import StableDiffusionSafetyChecker
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -981,7 +976,6 @@ def download_from_original_stable_diffusion_ckpt(
image_size: int = 512,
prediction_type: str = None,
model_type: str = None,
is_img2img: bool = False,
extract_ema: bool = False,
scheduler_type: str = "pndm",
num_in_channels: Optional[int] = None,
@@ -993,7 +987,8 @@ def download_from_original_stable_diffusion_ckpt(
clip_stats_path: Optional[str] = None,
controlnet: Optional[bool] = None,
load_safety_checker: bool = True,
) -> StableDiffusionPipeline:
pipeline_class: DiffusionPipeline = None,
) -> DiffusionPipeline:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
config file.
@@ -1031,12 +1026,29 @@ def download_from_original_stable_diffusion_ckpt(
Whether the attention computation should always be upcasted. This is necessary when running stable
diffusion 2.1.
device (`str`, *optional*, defaults to `None`):
The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
The device to use. Pass `None` to determine automatically.
from_safetensors (`str`, *optional*, defaults to `False`):
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
load_safety_checker (`bool`, *optional*, defaults to `True`):
Whether to load the safety checker or not. Defaults to `True`.
pipeline_class (`str`, *optional*, defaults to `None`):
The pipeline class to use. Pass `None` to determine automatically.
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""
# import pipelines here to avoid circular import error when using from_ckpt method
from diffusers import (
LDMTextToImagePipeline,
PaintByExamplePipeline,
StableDiffusionControlNetPipeline,
StableDiffusionPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
if pipeline_class is None:
pipeline_class = StableDiffusionPipeline
if prediction_type == "v-prediction":
prediction_type = "v_prediction"
@@ -1198,44 +1210,16 @@ def download_from_original_stable_diffusion_ckpt(
requires_safety_checker=False,
)
else:
if (
hasattr(original_config, "model")
and hasattr(original_config.model, "target")
and "LatentInpaintDiffusion" in original_config.model.target
):
pipe = StableDiffusionInpaintPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
else:
if is_img2img:
pipe = StableDiffusionImg2ImgPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
else:
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
else:
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
original_config, clip_stats_path=clip_stats_path, device=device
@@ -1326,41 +1310,15 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor=feature_extractor,
)
else:
if (
hasattr(original_config, "model")
and hasattr(original_config.model, "target")
and "LatentInpaintDiffusion" in original_config.model.target
):
pipe = StableDiffusionInpaintPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
else:
if is_img2img:
pipe = StableDiffusionImg2ImgPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
else:
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
@@ -1379,7 +1337,7 @@ def download_controlnet_from_original_ckpt(
upcast_attention: Optional[bool] = None,
device: str = None,
from_safetensors: bool = False,
) -> StableDiffusionPipeline:
) -> DiffusionPipeline:
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

View File

@@ -20,7 +20,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -53,13 +53,21 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

View File

@@ -156,6 +156,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

View File

@@ -23,7 +23,7 @@ from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -55,13 +55,20 @@ def preprocess(image):
return image
class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

View File

@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -92,13 +92,21 @@ def preprocess(image):
return image
class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

View File

@@ -22,7 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -138,13 +138,20 @@ def prepare_mask_and_masked_image(image, mask):
return mask, masked_image
class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

View File

@@ -22,7 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -82,13 +82,23 @@ def preprocess_mask(mask, scale_factor=8):
return mask
class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionInpaintPipelineLegacy(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

View File

@@ -20,7 +20,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -61,13 +61,20 @@ def preprocess(image):
return image
class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

View File

@@ -2,21 +2,6 @@
from ..utils import DummyObject, requires_backends
class TextualInversionLoaderMixin(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -36,6 +36,7 @@ from diffusers import (
UNet2DConditionModel,
logging,
)
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
@@ -865,6 +866,62 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
assert max_diff < 5e-2
@slow
@require_torch_gpu
class StableDiffusionPipelineCkptTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_download_from_hub(self):
ckpt_paths = [
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
"https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix_base.ckpt",
]
for ckpt_path in ckpt_paths:
pipe = StableDiffusionPipeline.from_ckpt(ckpt_path, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]
assert image_out.shape == (512, 512, 3)
def test_download_local(self):
filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt")
pipe = StableDiffusionPipeline.from_ckpt(filename, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]
assert image_out.shape == (512, 512, 3)
def test_download_ckpt_diff_format_is_same(self):
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
pipe = StableDiffusionPipeline.from_ckpt(ckpt_path)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_attn_processor(AttnProcessor())
pipe.to("cuda")
generator = torch.Generator(device="cpu").manual_seed(0)
image_ckpt = pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_attn_processor(AttnProcessor())
pipe.to("cuda")
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]
assert np.max(np.abs(image - image_ckpt)) < 1e-4
@nightly
@require_torch_gpu
class StableDiffusionPipelineNightlyTests(unittest.TestCase):