Improve SD XL (#3968)
* improve sd xl
* correct more
* finish
* make style
* fix more
committed by GitHub
parent 8bf80fc8d8
commit 187ea539ae
.github/workflows/push_tests.yml (1 changed line, vendored)
@@ -61,6 +61,7 @@ jobs:
       - name: Install dependencies
         run: |
+          apt-get update && apt-get install libsndfile1-dev libgl1 -y
           python -m pip install -e .[quality,test]

       - name: Environment
.github/workflows/push_tests_fast.yml (2 changed lines, vendored)
@@ -60,7 +60,7 @@ jobs:
       - name: Install dependencies
         run: |
-          apt-get update && apt-get install libsndfile1-dev -y
+          apt-get update && apt-get install libsndfile1-dev libgl1 -y
           python -m pip install -e .[quality,test]

       - name: Environment
@@ -247,6 +247,8 @@
       title: Safe Stable Diffusion
     - local: api/pipelines/stable_diffusion/stable_diffusion_2
       title: Stable Diffusion 2
+    - local: api/pipelines/stable_diffusion/stable_diffusion_xl
+      title: Stable Diffusion XL
     - local: api/pipelines/stable_diffusion/latent_upscale
       title: Stable-Diffusion-Latent-Upscaler
     - local: api/pipelines/stable_diffusion/upscale
@@ -32,6 +32,6 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio

 [[autodoc]] loaders.LoraLoaderMixin

-## FromCkptMixin
+## FromSingleFileMixin

-[[autodoc]] loaders.FromCkptMixin
+[[autodoc]] loaders.FromSingleFileMixin
@@ -31,7 +31,7 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
 - enable_xformers_memory_efficient_attention
 - disable_xformers_memory_efficient_attention
 - load_textual_inversion
-- from_ckpt
+- from_single_file
 - load_lora_weights
 - save_lora_weights
@@ -40,7 +40,7 @@ Available Checkpoints are:
 - enable_vae_tiling
 - disable_vae_tiling
 - load_textual_inversion
-- from_ckpt
+- from_single_file
 - load_lora_weights
 - save_lora_weights
@@ -26,7 +26,7 @@ This guide will show you how to convert other Stable Diffusion formats to be com

 ## PyTorch .ckpt

-The checkpoint - or `.ckpt` - format is commonly used to store and save models. The `.ckpt` file contains the entire model and is typically several GBs in size. While you can load and use a `.ckpt` file directly with the [`~StableDiffusionPipeline.from_ckpt`] method, it is generally better to convert the `.ckpt` file to 🤗 Diffusers so both formats are available.
+The checkpoint - or `.ckpt` - format is commonly used to store and save models. The `.ckpt` file contains the entire model and is typically several GBs in size. While you can load and use a `.ckpt` file directly with the [`~StableDiffusionPipeline.from_single_file`] method, it is generally better to convert the `.ckpt` file to 🤗 Diffusers so both formats are available.

 There are two options for converting a `.ckpt` file; use a Space to convert the checkpoint or convert the `.ckpt` file with a script.
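For context, the script-free route to that conversion is simply to load the single-file checkpoint and re-save it in the multi-folder layout. A minimal sketch (the local paths are placeholders, not part of this commit):

```py
# Minimal conversion sketch: load a single-file checkpoint with `from_single_file`
# and re-save it in the multi-folder 🤗 Diffusers layout. Paths are placeholders.
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
pipeline.save_pretrained("./v1-5-converted")  # writes model_index.json plus one subfolder per component
```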
@@ -21,12 +21,12 @@ from diffusers import DiffusionPipeline
 pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
 ```

-However, model weights are not necessarily stored in separate subfolders like in the example above. Sometimes, all the weights are stored in a single `.safetensors` file. In this case, if the weights are Stable Diffusion weights, you can load the file directly with the [`~diffusers.loaders.FromCkptMixin.from_ckpt`] method:
+However, model weights are not necessarily stored in separate subfolders like in the example above. Sometimes, all the weights are stored in a single `.safetensors` file. In this case, if the weights are Stable Diffusion weights, you can load the file directly with the [`~diffusers.loaders.FromSingleFileMixin.from_single_file`] method:

 ```py
 from diffusers import StableDiffusionPipeline

-pipeline = StableDiffusionPipeline.from_ckpt(
+pipeline = StableDiffusionPipeline.from_single_file(
     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
 )
 ```
@@ -11,7 +11,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers

@@ -410,7 +410,7 @@ def preprocess_mask(mask, batch_size, scale_factor=8):


 class StableDiffusionLongPromptWeightingPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
@@ -1276,13 +1276,19 @@ class LoraLoaderMixin:
         return new_state_dict, network_alpha


-class FromCkptMixin:
+class FromSingleFileMixin:
     """
     Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
     """

     @classmethod
-    def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
+    def from_ckpt(cls, *args, **kwargs):
+        deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
+        deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
+        return cls.from_single_file(*args, **kwargs)
+
+    @classmethod
+    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         r"""
         Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
         is set in evaluation mode (`model.eval()`) by default.

@@ -1361,16 +1367,16 @@ class FromCkptMixin:
         >>> from diffusers import StableDiffusionPipeline

         >>> # Download pipeline from huggingface.co and cache.
-        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        >>> pipeline = StableDiffusionPipeline.from_single_file(
         ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
         ... )

         >>> # Download pipeline from local file
         >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
-        >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly")
+        >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")

         >>> # Enable float16 and move to GPU
-        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        >>> pipeline = StableDiffusionPipeline.from_single_file(
         ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
         ...     torch_dtype=torch.float16,
         ... )
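The shim above keeps the old entry point working. A small sketch of the behavior after this change (URL taken from the docstring example; apart from the warning, both calls should return an equivalent pipeline):

```py
# `from_ckpt` is kept as a thin deprecation shim: it warns and forwards to `from_single_file`.
from diffusers import StableDiffusionPipeline

url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"

pipe_old = StableDiffusionPipeline.from_ckpt(url)         # emits a deprecation warning, removed in v0.21
pipe_new = StableDiffusionPipeline.from_single_file(url)  # preferred spelling going forward
```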
@@ -77,7 +77,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -26,7 +26,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring

@@ -95,7 +95,9 @@ def preprocess(image):


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class AltDiffusionImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image to image generation using Alt Diffusion.

@@ -105,7 +107,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -233,7 +233,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     if controlnet:
         unet_params = original_config.model.params.control_stage_config.params
     else:
-        if original_config.model.params.unet_config is not None:
+        if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
             unet_params = original_config.model.params.unet_config.params
         else:
             unet_params = original_config.model.params.network_config.params

@@ -1139,7 +1139,7 @@ def download_from_original_stable_diffusion_ckpt(
     return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """

-    # import pipelines here to avoid circular import error when using from_ckpt method
+    # import pipelines here to avoid circular import error when using from_single_file method
     from diffusers import (
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
@@ -1192,23 +1192,45 @@ def download_from_original_stable_diffusion_ckpt(
         checkpoint = checkpoint["state_dict"]

     if original_config_file is None:
-        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
+        key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"

         # model_type = "v1"
         config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

-        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+        if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
             # model_type = "v2"
             config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"

             if global_step == 110000:
                 # v2.1 needs to upcast attention
                 upcast_attention = True
+        elif key_name_sd_xl_base in checkpoint:
+            # only base xl has two text embedders
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+        elif key_name_sd_xl_refiner in checkpoint:
+            # only refiner xl has embedder and one text embedders
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

         original_config_file = BytesIO(requests.get(config_url).content)

     original_config = OmegaConf.load(original_config_file)

+    # Convert the text model.
+    if (
+        model_type is None
+        and "cond_stage_config" in original_config.model.params
+        and original_config.model.params.cond_stage_config is not None
+    ):
+        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+        logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
+    elif model_type is None and original_config.model.params.network_config is not None:
+        if original_config.model.params.network_config.params.context_dim == 2048:
+            model_type = "SDXL"
+        else:
+            model_type = "SDXL-Refiner"
+
     if num_in_channels is not None:
         original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
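Distilled from the branch above: the converter now decides which original config to fetch by probing the checkpoint for keys that only exist in a given variant. A hedged restatement of that key-probing logic as a standalone helper (the helper itself is illustrative, not part of the commit):

```py
# Illustrative helper mirroring the detection above: SDXL base checkpoints carry a second
# text embedder (`conditioner.embedders.1.*`), while the refiner only has `conditioner.embedders.0.*`.
KEY_SD_XL_BASE = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
KEY_SD_XL_REFINER = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"


def guess_sdxl_variant(state_dict):
    if KEY_SD_XL_BASE in state_dict:
        return "SDXL"          # base model: two text embedders
    if KEY_SD_XL_REFINER in state_dict:
        return "SDXL-Refiner"  # refiner: a single OpenCLIP embedder
    return None                # not an SDXL checkpoint
```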
@@ -1238,20 +1260,39 @@ def download_from_original_stable_diffusion_ckpt(
         checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
     )

-    num_train_timesteps = original_config.model.params.timesteps or 1000
-    beta_start = original_config.model.params.linear_start or 0.02
-    beta_end = original_config.model.params.linear_end or 0.085
+    num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000

-    scheduler = DDIMScheduler(
-        beta_end=beta_end,
-        beta_schedule="scaled_linear",
-        beta_start=beta_start,
-        num_train_timesteps=num_train_timesteps,
-        steps_offset=1,
-        clip_sample=False,
-        set_alpha_to_one=False,
-        prediction_type=prediction_type,
-    )
+    if model_type in ["SDXL", "SDXL-Refiner"]:
+        image_size = 1024
+        scheduler_dict = {
+            "beta_schedule": "scaled_linear",
+            "beta_start": 0.00085,
+            "beta_end": 0.012,
+            "interpolation_type": "linear",
+            "num_train_timesteps": num_train_timesteps,
+            "prediction_type": "epsilon",
+            "sample_max_value": 1.0,
+            "set_alpha_to_one": False,
+            "skip_prk_steps": True,
+            "steps_offset": 1,
+            "timestep_spacing": "leading",
+        }
+        scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
+        scheduler_type = "euler"
+        vae_path = "stabilityai/sdxl-vae"
+    else:
+        beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
+        beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
+        scheduler = DDIMScheduler(
+            beta_end=beta_end,
+            beta_schedule="scaled_linear",
+            beta_start=beta_start,
+            num_train_timesteps=num_train_timesteps,
+            steps_offset=1,
+            clip_sample=False,
+            set_alpha_to_one=False,
+            prediction_type=prediction_type,
+        )
     # make sure scheduler works correctly with DDIM
     scheduler.register_to_config(clip_sample=False)
@@ -1294,16 +1335,6 @@ def download_from_original_stable_diffusion_ckpt(
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)

-    # Convert the text model.
-    if model_type is None and original_config.model.params.cond_stage_config is not None:
-        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-        logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
-    elif model_type is None and original_config.model.params.network_config is not None:
-        if original_config.model.params.network_config.params.context_dim == 2048:
-            model_type = "SDXL"
-        else:
-            model_type = "SDXL-Refiner"
-
     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (

@@ -69,7 +69,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

@@ -79,7 +79,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (

@@ -98,7 +98,9 @@ def preprocess(image):
     return image


-class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.

@@ -108,7 +110,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (

@@ -85,7 +85,7 @@ def preprocess_mask(mask, batch_size, scale_factor=8):


 class StableDiffusionInpaintPipelineLegacy(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

@@ -96,7 +96,7 @@ class StableDiffusionInpaintPipelineLegacy(
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -22,7 +22,7 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import VaeImageProcessorLDM3D
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (

@@ -74,7 +74,9 @@ class LDM3DPipelineOutput(BaseOutput):
     nsfw_content_detected: Optional[List[bool]]


-class StableDiffusionLDM3DPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionLDM3DPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-image and 3d generation using LDM3D. LDM3D: Latent Diffusion Model for 3D:
     https://arxiv.org/abs/2305.10853

@@ -85,7 +87,7 @@ class StableDiffusionLDM3DPipeline(DiffusionPipeline, TextualInversionLoaderMixi
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -19,7 +19,7 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (

@@ -59,7 +59,9 @@ EXAMPLE_DOC_STRING = """
 """


-class StableDiffusionParadigmsPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionParadigmsPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Parallelized version of StableDiffusionPipeline, based on the paper https://arxiv.org/abs/2305.16317 This pipeline
     parallelizes the denoising steps to generate a single image faster (more akin to model parallelism).

@@ -72,7 +74,7 @@ class StableDiffusionParadigmsPipeline(DiffusionPipeline, TextualInversionLoader
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -19,7 +19,7 @@ import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,

@@ -73,7 +73,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionXLPipeline(DiffusionPipeline):
+class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

@@ -83,7 +83,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -541,9 +541,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
-        original_size: Tuple[int, int] = (1024, 1024),
+        original_size: Optional[Tuple[int, int]] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
-        target_size: Tuple[int, int] = (1024, 1024),
+        target_size: Optional[Tuple[int, int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.

@@ -629,6 +629,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor

+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
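With these defaults, callers no longer need to pass the micro-conditioning sizes in the common case. A hedged usage sketch (the repo id is a placeholder for whichever SDXL weights are being used, not something introduced by this commit):

```py
# Sketch of the new defaults: original_size/target_size fall back to the output resolution,
# so they only need to be set explicitly when simulating a different source/target size.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9",  # placeholder repo id
    torch_dtype=torch.float16,
).to("cuda")

prompt = "an astronaut riding a horse"

# default micro-conditioning: original_size == target_size == (height, width)
image = pipe(prompt).images[0]

# explicit micro-conditioning, e.g. to mimic an upscaled low-resolution source image
image = pipe(prompt, original_size=(512, 512), target_size=(1024, 1024)).images[0]
```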
@@ -21,7 +21,7 @@ import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,

@@ -78,7 +78,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
+class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

@@ -88,7 +88,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -136,7 +136,6 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

         self.watermark = StableDiffusionXLWatermarker()

@@ -631,9 +630,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
-        original_size: Tuple[int, int] = (1024, 1024),
+        original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
-        target_size: Tuple[int, int] = (1024, 1024),
+        target_size: Tuple[int, int] = None,
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
     ):

@@ -778,6 +777,13 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
         # 7. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
         # 8. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
         add_time_ids, add_neg_time_ids = self._get_add_time_ids(
@@ -1029,7 +1029,7 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
         ]

         for ckpt_path in ckpt_paths:
-            pipe = StableDiffusionPipeline.from_ckpt(ckpt_path, torch_dtype=torch.float16)
+            pipe = StableDiffusionPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
             pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
             pipe.to("cuda")

@@ -1040,7 +1040,7 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
     def test_download_local(self):
         filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt")

-        pipe = StableDiffusionPipeline.from_ckpt(filename, torch_dtype=torch.float16)
+        pipe = StableDiffusionPipeline.from_single_file(filename, torch_dtype=torch.float16)
         pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
         pipe.to("cuda")

@@ -1051,7 +1051,7 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
     def test_download_ckpt_diff_format_is_same(self):
         ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"

-        pipe = StableDiffusionPipeline.from_ckpt(ckpt_path)
+        pipe = StableDiffusionPipeline.from_single_file(ckpt_path)
         pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
         pipe.unet.set_attn_processor(AttnProcessor())
         pipe.to("cuda")