From 187ea539aed54872675cd27f6a58b27084a173ab Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 6 Jul 2023 18:11:20 +0200
Subject: [PATCH] Improve SD XL (#3968)

* improve sd xl
* correct more
* finish
* make style
* fix more
---
 .github/workflows/push_tests.yml              |  1 +
 .github/workflows/push_tests_fast.yml         |  2 +-
 docs/source/en/_toctree.yml                   |  2 +
 docs/source/en/api/loaders.mdx                |  4 +-
 .../pipelines/stable_diffusion/img2img.mdx    |  2 +-
 .../pipelines/stable_diffusion/text2img.mdx   |  2 +-
 .../en/using-diffusers/other-formats.mdx      |  2 +-
 .../en/using-diffusers/using_safetensors.mdx  |  4 +-
 examples/community/lpw_stable_diffusion.py    |  4 +-
 src/diffusers/loaders.py                      | 16 ++--
 .../alt_diffusion/pipeline_alt_diffusion.py   |  2 +-
 .../pipeline_alt_diffusion_img2img.py         |  8 +-
 .../stable_diffusion/convert_from_ckpt.py     | 85 +++++++++++++------
 .../pipeline_stable_diffusion.py              |  6 +-
 .../pipeline_stable_diffusion_img2img.py      |  8 +-
 ...ipeline_stable_diffusion_inpaint_legacy.py |  6 +-
 .../pipeline_stable_diffusion_ldm3d.py        |  8 +-
 .../pipeline_stable_diffusion_paradigms.py    |  8 +-
 .../pipeline_stable_diffusion_xl.py           | 13 +--
 .../pipeline_stable_diffusion_xl_img2img.py   | 18 ++--
 .../stable_diffusion/test_stable_diffusion.py |  6 +-
 21 files changed, 132 insertions(+), 75 deletions(-)

diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index 567cd5f5b0..5ec8dbdc40 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -61,6 +61,7 @@ jobs:
     - name: Install dependencies
       run: |
+        apt-get update && apt-get install libsndfile1-dev libgl1 -y
         python -m pip install -e .[quality,test]

     - name: Environment
diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml
index adf4fc8a87..acd59ef80d 100644
--- a/.github/workflows/push_tests_fast.yml
+++ b/.github/workflows/push_tests_fast.yml
@@ -60,7 +60,7 @@ jobs:
     - name: Install dependencies
       run: |
-        apt-get update && apt-get install libsndfile1-dev -y
+        apt-get update && apt-get install libsndfile1-dev libgl1 -y
         python -m pip install -e .[quality,test]

     - name: Environment
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 470d8c5c18..ad1c7c0aab 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -247,6 +247,8 @@
         title: Safe Stable Diffusion
       - local: api/pipelines/stable_diffusion/stable_diffusion_2
         title: Stable Diffusion 2
+      - local: api/pipelines/stable_diffusion/stable_diffusion_xl
+        title: Stable Diffusion XL
       - local: api/pipelines/stable_diffusion/latent_upscale
         title: Stable-Diffusion-Latent-Upscaler
       - local: api/pipelines/stable_diffusion/upscale
diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx
index a236a6c70b..57891d23de 100644
--- a/docs/source/en/api/loaders.mdx
+++ b/docs/source/en/api/loaders.mdx
@@ -32,6 +32,6 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio

 [[autodoc]] loaders.LoraLoaderMixin

-## FromCkptMixin
+## FromSingleFileMixin

-[[autodoc]] loaders.FromCkptMixin
+[[autodoc]] loaders.FromSingleFileMixin
diff --git a/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx b/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx
index 7959c58860..c70f9ac9dc 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx
+++ b/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx
@@ -31,7 +31,7 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
   - enable_xformers_memory_efficient_attention
   - disable_xformers_memory_efficient_attention
   - load_textual_inversion
-  - from_ckpt
+  - from_single_file
   - load_lora_weights
   - save_lora_weights
diff --git a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx
index ce78434fdb..0e3f511175 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx
+++ b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx
@@ -40,7 +40,7 @@ Available Checkpoints are:
   - enable_vae_tiling
   - disable_vae_tiling
   - load_textual_inversion
-  - from_ckpt
+  - from_single_file
   - load_lora_weights
   - save_lora_weights
diff --git a/docs/source/en/using-diffusers/other-formats.mdx b/docs/source/en/using-diffusers/other-formats.mdx
index 2aeb9f3ae2..b58d00fce1 100644
--- a/docs/source/en/using-diffusers/other-formats.mdx
+++ b/docs/source/en/using-diffusers/other-formats.mdx
@@ -26,7 +26,7 @@ This guide will show you how to convert other Stable Diffusion formats to be com

 ## PyTorch .ckpt

-The checkpoint - or `.ckpt` - format is commonly used to store and save models. The `.ckpt` file contains the entire model and is typically several GBs in size. While you can load and use a `.ckpt` file directly with the [`~StableDiffusionPipeline.from_ckpt`] method, it is generally better to convert the `.ckpt` file to 🤗 Diffusers so both formats are available.
+The checkpoint - or `.ckpt` - format is commonly used to store and save models. The `.ckpt` file contains the entire model and is typically several GBs in size. While you can load and use a `.ckpt` file directly with the [`~StableDiffusionPipeline.from_single_file`] method, it is generally better to convert the `.ckpt` file to 🤗 Diffusers so both formats are available.

 There are two options for converting a `.ckpt` file; use a Space to convert the checkpoint or convert the `.ckpt` file with a script.
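For the docs change above: once a checkpoint loads through `from_single_file`, a regular `save_pretrained` call is enough to produce the multi-folder Diffusers layout the guide recommends keeping alongside the single file. A minimal sketch, reusing the OrangeMixs URL from the docs example; the output directory name is only illustrative:

```py
from diffusers import StableDiffusionPipeline

# Load the single-file checkpoint (URL taken from the docs example in this patch).
pipe = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
)

# Save it in the multi-folder Diffusers format so both formats stay available.
pipe.save_pretrained("./AbyssOrangeMix-diffusers")  # output path is illustrative
```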
diff --git a/docs/source/en/using-diffusers/using_safetensors.mdx b/docs/source/en/using-diffusers/using_safetensors.mdx
index c312ab5970..a7bc0a7c9c 100644
--- a/docs/source/en/using-diffusers/using_safetensors.mdx
+++ b/docs/source/en/using-diffusers/using_safetensors.mdx
@@ -21,12 +21,12 @@ from diffusers import DiffusionPipeline
 pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
 ```

-However, model weights are not necessarily stored in separate subfolders like in the example above. Sometimes, all the weights are stored in a single `.safetensors` file. In this case, if the weights are Stable Diffusion weights, you can load the file directly with the [`~diffusers.loaders.FromCkptMixin.from_ckpt`] method:
+However, model weights are not necessarily stored in separate subfolders like in the example above. Sometimes, all the weights are stored in a single `.safetensors` file. In this case, if the weights are Stable Diffusion weights, you can load the file directly with the [`~diffusers.loaders.FromSingleFileMixin.from_single_file`] method:

 ```py
 from diffusers import StableDiffusionPipeline

-pipeline = StableDiffusionPipeline.from_ckpt(
+pipeline = StableDiffusionPipeline.from_single_file(
     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
 )
 ```
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index 56fb903c71..2970aae4b1 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -11,7 +11,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -410,7 +410,7 @@ def preprocess_mask(mask, batch_size, scale_factor=8):

 class StableDiffusionLongPromptWeightingPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 1bdd33fa80..a0be20c543 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1276,13 +1276,19 @@ class LoraLoaderMixin:
         return new_state_dict, network_alpha


-class FromCkptMixin:
+class FromSingleFileMixin:
    """
    Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
    """

    @classmethod
-    def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
+    def from_ckpt(cls, *args, **kwargs):
+        deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
+        deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
+        return cls.from_single_file(*args, **kwargs)
+
+    @classmethod
+    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
        is set in evaluation mode (`model.eval()`) by default.
@@ -1361,16 +1367,16 @@ class FromCkptMixin:
        >>> from diffusers import StableDiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
-        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        >>> pipeline = StableDiffusionPipeline.from_single_file(
        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
        ... )

        >>> # Download pipeline from local file
        >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
-        >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly")
+        >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")

        >>> # Enable float16 and move to GPU
-        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        >>> pipeline = StableDiffusionPipeline.from_single_file(
        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
        ...     torch_dtype=torch.float16,
        ... )
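The shim above keeps `from_ckpt` working while routing everything to `from_single_file`. A minimal migration sketch for callers, assuming a locally downloaded v1-5 checkpoint (the file name is only illustrative):

```py
import torch
from diffusers import StableDiffusionPipeline

# Deprecated spelling: still works, but only forwards to `from_single_file`
# and emits a deprecation warning (removal is scheduled for diffusers v0.21).
# pipe = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.ckpt")

# New spelling:
pipe = StableDiffusionPipeline.from_single_file(
    "./v1-5-pruned-emaonly.ckpt", torch_dtype=torch.float16
)
pipe.to("cuda")
```

The same method becomes available on every pipeline that gains `FromSingleFileMixin` in the hunks below.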
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index b79e4f7214..5a4746d24e 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -77,7 +77,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 5903f97aca..21c1f0591a 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -26,7 +26,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@@ -95,7 +95,9 @@ def preprocess(image):

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class AltDiffusionImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image to image generation using Alt Diffusion.
@@ -105,7 +107,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 99cfcb8067..ef2333a18d 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -233,7 +233,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     if controlnet:
         unet_params = original_config.model.params.control_stage_config.params
     else:
-        if original_config.model.params.unet_config is not None:
+        if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
             unet_params = original_config.model.params.unet_config.params
         else:
             unet_params = original_config.model.params.network_config.params
@@ -1139,7 +1139,7 @@ def download_from_original_stable_diffusion_ckpt(
     return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
-    # import pipelines here to avoid circular import error when using from_ckpt method
+    # import pipelines here to avoid circular import error when using from_single_file method
     from diffusers import (
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
@@ -1192,23 +1192,45 @@ def download_from_original_stable_diffusion_ckpt(
         checkpoint = checkpoint["state_dict"]

     if original_config_file is None:
-        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
+        key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"

         # model_type = "v1"
         config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

-        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+        if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
             # model_type = "v2"
             config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"

             if global_step == 110000:
                 # v2.1 needs to upcast attention
                 upcast_attention = True
+        elif key_name_sd_xl_base in checkpoint:
+            # only the base xl has two text embedders
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+        elif key_name_sd_xl_refiner in checkpoint:
+            # the refiner xl has only one text embedder
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

         original_config_file = BytesIO(requests.get(config_url).content)

     original_config = OmegaConf.load(original_config_file)

+    # Convert the text model.
+    if (
+        model_type is None
+        and "cond_stage_config" in original_config.model.params
+        and original_config.model.params.cond_stage_config is not None
+    ):
+        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+        logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
+    elif model_type is None and original_config.model.params.network_config is not None:
+        if original_config.model.params.network_config.params.context_dim == 2048:
+            model_type = "SDXL"
+        else:
+            model_type = "SDXL-Refiner"
+
     if num_in_channels is not None:
         original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
@@ -1238,20 +1260,39 @@ def download_from_original_stable_diffusion_ckpt(
         checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
     )

-    num_train_timesteps = original_config.model.params.timesteps or 1000
-    beta_start = original_config.model.params.linear_start or 0.02
-    beta_end = original_config.model.params.linear_end or 0.085
+    num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000

-    scheduler = DDIMScheduler(
-        beta_end=beta_end,
-        beta_schedule="scaled_linear",
-        beta_start=beta_start,
-        num_train_timesteps=num_train_timesteps,
-        steps_offset=1,
-        clip_sample=False,
-        set_alpha_to_one=False,
-        prediction_type=prediction_type,
-    )
+    if model_type in ["SDXL", "SDXL-Refiner"]:
+        image_size = 1024
+        scheduler_dict = {
+            "beta_schedule": "scaled_linear",
+            "beta_start": 0.00085,
+            "beta_end": 0.012,
+            "interpolation_type": "linear",
+            "num_train_timesteps": num_train_timesteps,
+            "prediction_type": "epsilon",
+            "sample_max_value": 1.0,
+            "set_alpha_to_one": False,
+            "skip_prk_steps": True,
+            "steps_offset": 1,
+            "timestep_spacing": "leading",
+        }
+        scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
+        scheduler_type = "euler"
+        vae_path = "stabilityai/sdxl-vae"
+    else:
+        beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
+        beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
+        scheduler = DDIMScheduler(
+            beta_end=beta_end,
+            beta_schedule="scaled_linear",
+            beta_start=beta_start,
+            num_train_timesteps=num_train_timesteps,
+            steps_offset=1,
+            clip_sample=False,
+            set_alpha_to_one=False,
+            prediction_type=prediction_type,
+        )

     # make sure scheduler works correctly with DDIM
     scheduler.register_to_config(clip_sample=False)
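For the SDXL branch above, the scheduler is built from a plain config dict rather than from explicit keyword arguments; `from_config` is used so the extra keys that other schedulers need do not have to be stripped by hand. A standalone sketch of the same pattern, with the values copied from the hunk (only the keys relevant to `EulerDiscreteScheduler` are kept here):

```py
from diffusers import EulerDiscreteScheduler

# Same defaults as the SDXL branch in the patch above.
sdxl_scheduler_config = {
    "beta_schedule": "scaled_linear",
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "interpolation_type": "linear",
    "num_train_timesteps": 1000,
    "prediction_type": "epsilon",
    "steps_offset": 1,
    "timestep_spacing": "leading",
}

scheduler = EulerDiscreteScheduler.from_config(sdxl_scheduler_config)
print(scheduler.config.prediction_type)  # "epsilon"
```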
@@ -1294,16 +1335,6 @@ def download_from_original_stable_diffusion_ckpt(
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)

-    # Convert the text model.
-    if model_type is None and original_config.model.params.cond_stage_config is not None:
-        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-        logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
-    elif model_type is None and original_config.model.params.network_config is not None:
-        if original_config.model.params.network_config.params.context_dim == 2048:
-            model_type = "SDXL"
-        else:
-            model_type = "SDXL-Refiner"
-
     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
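With the hunk above, model-type detection no longer happens at text-model conversion time; it is driven by the checkpoint keys and the inferred config earlier in the function. A hedged, standalone sketch of the same key probing for a local file (the file name is illustrative, the keys are copied from the patch):

```py
import safetensors.torch

state_dict = safetensors.torch.load_file("sd_xl_base_1.0.safetensors")  # illustrative path

# The base model carries two text embedders, the refiner only one, so a single
# key per variant is enough to tell them apart (same keys as in the patch).
key_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"

if key_sd_xl_base in state_dict:
    print("SDXL base checkpoint -> sd_xl_base.yaml")
elif key_sd_xl_refiner in state_dict:
    print("SDXL refiner checkpoint -> sd_xl_refiner.yaml")
else:
    print("not an SDXL checkpoint")
```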
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 8368668ebe..9ad4d404fd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -69,7 +69,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -79,7 +79,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index e9e91b646e..f8874ba2cf 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -98,7 +98,9 @@ def preprocess(image):
     return image


-class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.
@@ -108,7 +110,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index 55d571ab09..483f27ae39 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -85,7 +85,7 @@ def preprocess_mask(mask, batch_size, scale_factor=8):

 class StableDiffusionInpaintPipelineLegacy(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion.
     *This is an experimental feature*.
@@ -96,7 +96,7 @@ class StableDiffusionInpaintPipelineLegacy(
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
index 2df9c46f0b..85f628ca82 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
@@ -22,7 +22,7 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import VaeImageProcessorLDM3D
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -74,7 +74,9 @@ class LDM3DPipelineOutput(BaseOutput):
     nsfw_content_detected: Optional[List[bool]]


-class StableDiffusionLDM3DPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionLDM3DPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-image and 3d generation using LDM3D.
     LDM3D: Latent Diffusion Model for 3D: https://arxiv.org/abs/2305.10853
@@ -85,7 +87,7 @@ class StableDiffusionLDM3DPipeline(DiffusionPipeline, TextualInversionLoaderMixi
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
index 33549ebb0e..2239e3853a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
@@ -19,7 +19,7 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -59,7 +59,9 @@ EXAMPLE_DOC_STRING = """
 """


-class StableDiffusionParadigmsPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionParadigmsPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Parallelized version of StableDiffusionPipeline, based on the paper
     https://arxiv.org/abs/2305.16317
     This pipeline parallelizes the denoising steps to generate a single image faster (more akin to model parallelism).
@@ -72,7 +74,7 @@ class StableDiffusionParadigmsPipeline(DiffusionPipeline, TextualInversionLoader
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index c50381c2eb..142aac94b9 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -19,7 +19,7 @@ import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
@@ -73,7 +73,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionXLPipeline(DiffusionPipeline):
+class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -83,7 +83,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -541,9 +541,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
-        original_size: Tuple[int, int] = (1024, 1024),
+        original_size: Optional[Tuple[int, int]] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
-        target_size: Tuple[int, int] = (1024, 1024),
+        target_size: Optional[Tuple[int, int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -629,6 +629,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor

+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 329a626ada..f699e23310 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -21,7 +21,7 @@ import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
@@ -78,7 +78,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
+class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -88,7 +88,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
@@ -136,7 +136,6 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.vae_scale_factor = 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

         self.watermark = StableDiffusionXLWatermarker()
@@ -631,9 +630,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
-        original_size: Tuple[int, int] = (1024, 1024),
+        original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
-        target_size: Tuple[int, int] = (1024, 1024),
+        target_size: Tuple[int, int] = None,
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
     ):
@@ -778,6 +777,13 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
         # 7. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
         # 8. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
         add_time_ids, add_neg_time_ids = self._get_add_time_ids(
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 7daf3fcda4..a10462a345 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -1029,7 +1029,7 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
         ]

         for ckpt_path in ckpt_paths:
-            pipe = StableDiffusionPipeline.from_ckpt(ckpt_path, torch_dtype=torch.float16)
+            pipe = StableDiffusionPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
             pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
             pipe.to("cuda")
@@ -1040,7 +1040,7 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
     def test_download_local(self):
         filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt")

-        pipe = StableDiffusionPipeline.from_ckpt(filename, torch_dtype=torch.float16)
+        pipe = StableDiffusionPipeline.from_single_file(filename, torch_dtype=torch.float16)
         pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
         pipe.to("cuda")
@@ -1051,7 +1051,7 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
     def test_download_ckpt_diff_format_is_same(self):
         ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"

-        pipe = StableDiffusionPipeline.from_ckpt(ckpt_path)
+        pipe = StableDiffusionPipeline.from_single_file(ckpt_path)
         pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
         pipe.unet.set_attn_processor(AttnProcessor())
         pipe.to("cuda")
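Taken together, the patch means an SDXL checkpoint distributed as a single `.safetensors` file can be loaded directly, and the size conditioning no longer needs to be spelled out on every call. A hedged end-to-end sketch (the checkpoint file name and prompt are illustrative):

```py
import torch
from diffusers import StableDiffusionXLPipeline

# Single-file loading now also covers SDXL, since the pipeline gains FromSingleFileMixin.
pipe = StableDiffusionXLPipeline.from_single_file(
    "./sd_xl_base_1.0.safetensors", torch_dtype=torch.float16
)
pipe.to("cuda")

# original_size / target_size default to (height, width), so a plain call is enough ...
image = pipe("an astronaut riding a horse", height=1024, width=1024).images[0]

# ... and they only need to be passed to steer the micro-conditioning explicitly.
image = pipe(
    "an astronaut riding a horse",
    original_size=(2048, 2048),
    target_size=(1024, 1024),
).images[0]
```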