Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27)
[Refactor] Better align from_single_file logic with from_pretrained (#7496)
* refactor unet single file loading a bit
* retrieve the unet from create_diffusers_unet_model_from_ldm
* numerous incremental "update" commits (squashed)
* tests
* Update docs/source/en/api/single_file.md (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update docs/source/en/api/loaders/single_file.md (Co-authored-by: YiYi Xu <yixu310@gmail.com>; Sayak Paul <spsayakpaul@gmail.com>)
* Update src/diffusers/loaders/single_file.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2  .github/workflows/push_tests.yml (vendored)
@@ -124,7 +124,7 @@ jobs:
        shell: bash
    strategy:
      matrix:
        module: [models, schedulers, lora, others]
        module: [models, schedulers, lora, others, single_file]
    steps:
    - name: Checkout diffusers
      uses: actions/checkout@v3
@@ -10,13 +10,124 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->

# Single files
# Loading Pipelines and Models via `from_single_file`

Diffusers supports loading pretrained pipeline (or model) weights stored in a single file, such as a `ckpt` or `safetensors` file. These single file types are typically produced from community trained models. There are three classes for loading single file weights:
The `from_single_file` method allows you to load supported pipelines using a single checkpoint file as opposed to the folder format used by Diffusers. This is useful if you are working with many of the Stable Diffusion Web UIs (such as A1111) that extensively rely on a single file to distribute all the components of a diffusion model.

- [`FromSingleFileMixin`] supports loading pretrained pipeline weights stored in a single file, which can either be a `ckpt` or `safetensors` file.
- [`FromOriginalVAEMixin`] supports loading a pretrained [`AutoencoderKL`] from VAE weights stored in a single file, which can either be a `ckpt` or `safetensors` file.
- [`FromOriginalControlnetMixin`] supports loading pretrained ControlNet weights stored in a single file, which can either be a `ckpt` or `safetensors` file.

The `from_single_file` method also supports loading models in their originally distributed format. This means that supported models that have been finetuned with other services can be loaded directly into supported Diffusers model objects and pipelines.
## Pipelines that currently support `from_single_file` loading

- [`StableDiffusionPipeline`]
- [`StableDiffusionImg2ImgPipeline`]
- [`StableDiffusionInpaintPipeline`]
- [`StableDiffusionControlNetPipeline`]
- [`StableDiffusionControlNetImg2ImgPipeline`]
- [`StableDiffusionControlNetInpaintPipeline`]
- [`StableDiffusionUpscalePipeline`]
- [`StableDiffusionXLPipeline`]
- [`StableDiffusionXLImg2ImgPipeline`]
- [`StableDiffusionXLInpaintPipeline`]
- [`StableDiffusionXLInstructPix2PixPipeline`]
- [`StableDiffusionXLControlNetPipeline`]
- [`StableDiffusionXLKDiffusionPipeline`]
- [`LatentConsistencyModelPipeline`]
- [`LatentConsistencyModelImg2ImgPipeline`]
- [`StableDiffusionControlNetXSPipeline`]
- [`StableDiffusionXLControlNetXSPipeline`]
- [`LEditsPPPipelineStableDiffusion`]
- [`LEditsPPPipelineStableDiffusionXL`]
- [`PIAPipeline`]

## Models that currently support `from_single_file` loading

- [`UNet2DConditionModel`]
- [`StableCascadeUNet`]
- [`AutoencoderKL`]
- [`ControlNetModel`]
## Usage Examples

## Loading a Pipeline using `from_single_file`

```python
from diffusers import StableDiffusionXLPipeline

ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path)
```
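
The loaded pipeline behaves like any pipeline created with `from_pretrained`. As a quick usage sketch continuing from the snippet above (the prompt and step count are placeholders):

```python
# Move the pipeline to the GPU and run inference as usual.
pipe = pipe.to("cuda")
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]
image.save("astronaut.png")
```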
## Setting components in a Pipeline using `from_single_file`

Swap components of the pipeline by passing them directly to the `from_single_file` method, for example if you would like to use a different scheduler than the pipeline default.

```python
from diffusers import StableDiffusionXLPipeline, DDIMScheduler

ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"

scheduler = DDIMScheduler()
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, scheduler=scheduler)
```

```python
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"

# Load the ControlNet component from its own model repository (an example canny ControlNet is used here).
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = StableDiffusionControlNetPipeline.from_single_file(ckpt_path, controlnet=controlnet)
```
## Loading a Model using `from_single_file`

```python
from diffusers import StableCascadeUNet

ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
model = StableCascadeUNet.from_single_file(ckpt_path)
```
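
The other supported model classes follow the same pattern. For instance, a sketch of loading an [`AutoencoderKL`] from a single-file VAE checkpoint (the checkpoint URL here is just an example):

```python
from diffusers import AutoencoderKL

# Any single-file VAE checkpoint in `.ckpt` or `.safetensors` format can be used here.
vae_ckpt_path = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
vae = AutoencoderKL.from_single_file(vae_ckpt_path)
```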
## Using a Diffusers model repository to configure single file loading

Under the hood, `from_single_file` will try to determine a model repository to use to configure the components of the pipeline. You can also pass in a repository id to the `config` argument of the `from_single_file` method to explicitly set the repository to use.

```python
from diffusers import StableDiffusionXLPipeline

ckpt_path = "https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors"
repo_id = "segmind/SSD-1B"

pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, config=repo_id)
```

## Override configuration options when using single file loading

Override the default model or pipeline configuration options when using `from_single_file` by passing in the relevant arguments directly to the `from_single_file` method. Any argument that is supported by the model or pipeline class can be configured in this way:
```python
from diffusers import StableDiffusionXLInstructPix2PixPipeline

ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
pipe = StableDiffusionXLInstructPix2PixPipeline.from_single_file(ckpt_path, config="diffusers/sdxl-instructpix2pix-768", is_cosxl_edit=True)
```

```python
from diffusers import UNet2DConditionModel

ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
model = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)
```

In the earlier example where we explicitly passed `config="segmind/SSD-1B"`, `from_single_file` uses this [configuration file](https://huggingface.co/segmind/SSD-1B/blob/main/unet/config.json) from the `"unet"` subfolder in `"segmind/SSD-1B"` to configure the UNet component included in the checkpoint. Similarly, it uses the `config.json` file from the `"vae"` subfolder to configure the VAE model, the `config.json` file from the `"text_encoder"` folder to configure the text encoder, and so on.

Note that most of the time you do not need to explicitly pass a `config` argument; `from_single_file` will automatically map the checkpoint to an appropriate repo id (we will discuss this in more detail in the next section). However, this can be useful in cases where model components might have been changed from what was originally distributed, or in cases where a checkpoint file might not have the necessary metadata to correctly determine the configuration to use for the pipeline.
<Tip>

@@ -24,14 +135,114 @@ To learn more about how to load single file weights, see the [Load different Sta

</Tip>

## Working with local files

As of `diffusers>=0.28.0` the `from_single_file` method will attempt to configure a pipeline or model by first inferring the model type from the checkpoint file and then using the model type to determine the appropriate model repo configuration to use from the Hugging Face Hub. For example, any single file checkpoint based on the Stable Diffusion XL base model will use the [`stabilityai/stable-diffusion-xl-base-1.0`](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model repo to configure the pipeline.

If you are working in an environment with restricted internet access, it is recommended to download the config files and checkpoints for the model to your preferred directory and pass the local paths to the `pretrained_model_link_or_path` and `config` arguments of the `from_single_file` method.
```python
from huggingface_hub import hf_hub_download, snapshot_download

from diffusers import StableDiffusionXLPipeline

my_local_checkpoint_path = hf_hub_download(
    repo_id="segmind/SSD-1B",
    filename="SSD-1B.safetensors"
)

my_local_config_path = snapshot_download(
    repo_id="segmind/SSD-1B",
    allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
)

pipe = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
```

By default this will download the checkpoints and config files to the [Hugging Face Hub cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache). You can also specify a local directory to download the files to by passing the `local_dir` argument to the `hf_hub_download` and `snapshot_download` functions.
```python
from huggingface_hub import hf_hub_download, snapshot_download

from diffusers import StableDiffusionXLPipeline

my_local_checkpoint_path = hf_hub_download(
    repo_id="segmind/SSD-1B",
    filename="SSD-1B.safetensors",
    local_dir="my_local_checkpoints"
)

my_local_config_path = snapshot_download(
    repo_id="segmind/SSD-1B",
    allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"],
    local_dir="my_local_config"
)

pipe = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
```
## Working with local files on file systems that do not support symlinking

By default the `from_single_file` method relies on the `huggingface_hub` caching mechanism to fetch and store checkpoints and config files for models and pipelines. If you are working with a file system that does not support symlinking, it is recommended that you first download the checkpoint file to a local directory and disable symlinking by passing the `local_dir_use_symlinks=False` argument to the `hf_hub_download` and `snapshot_download` functions.

```python
from huggingface_hub import hf_hub_download, snapshot_download

my_local_checkpoint_path = hf_hub_download(
    repo_id="segmind/SSD-1B",
    filename="SSD-1B.safetensors",
    local_dir="my_local_checkpoints",
    local_dir_use_symlinks=False
)
print("My local checkpoint: ", my_local_checkpoint_path)

my_local_config_path = snapshot_download(
    repo_id="segmind/SSD-1B",
    allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"],
    local_dir="my_local_config",
    local_dir_use_symlinks=False,
)
print("My local config: ", my_local_config_path)
```
Then pass the local paths to the `pretrained_model_link_or_path` and `config` arguments of the `from_single_file` method.

```python
pipe = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
```

<Tip>
Disabling symlinking means that the `huggingface_hub` caching mechanism has no way to determine whether a file has already been downloaded to the local directory. This means that the `hf_hub_download` and `snapshot_download` functions will download files to the local directory each time they are executed. If you are disabling symlinking, it is recommended that you separate the model download and loading steps to avoid downloading the same file multiple times.

</Tip>
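
For example, a minimal sketch of separating the two steps (the directory name is a placeholder and the repo follows the examples above): download the checkpoint once, guarded by a simple existence check, and reuse it on subsequent runs.

```python
import os

from huggingface_hub import hf_hub_download

local_dir = "my_local_checkpoints"  # placeholder directory
ckpt_file = os.path.join(local_dir, "SSD-1B.safetensors")

# With symlinking disabled, the huggingface_hub cache cannot tell that the file
# already exists, so guard the download step yourself.
if not os.path.exists(ckpt_file):
    ckpt_file = hf_hub_download(
        repo_id="segmind/SSD-1B",
        filename="SSD-1B.safetensors",
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )
```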
## Using the original configuration file of a model

If you would like to configure the parameters of the model components in the pipeline using the original YAML configuration file, you can pass a local path or URL to the original configuration file to the `original_config` argument of the `from_single_file` method.

```python
from diffusers import StableDiffusionXLPipeline

ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
original_config = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"

pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, original_config=original_config)
```

In the example above, the `original_config` file is only used to configure the parameters of the individual model components of the pipeline. For example, it is used to configure parameters such as the `in_channels` of the `vae` and `unet` models. It is not used to determine the type of component objects in the pipeline.

<Tip>
When using `original_config` with `local_files_only=True`, Diffusers will attempt to infer the components based on the type signatures of the pipeline class, rather than attempting to fetch the pipeline config from the Hugging Face Hub. This is to prevent backwards breaking changes in existing code that might not be able to connect to the internet to fetch the necessary pipeline config files.

This is not as reliable as providing a path to a local config repo and might lead to errors when configuring the pipeline. To avoid this, please run the pipeline with `local_files_only=False` once to download the appropriate pipeline config files to the local cache.
</Tip>
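
For example, using the same `ckpt_path` and `original_config` variables as above, a single online run is enough to populate the local cache:

```python
# One-time run with network access to cache the pipeline config files.
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, original_config=original_config, local_files_only=False)

# Subsequent runs can then load with `local_files_only=True`.
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, original_config=original_config, local_files_only=True)
```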
## FromSingleFileMixin

[[autodoc]] loaders.single_file.FromSingleFileMixin

## FromOriginalVAEMixin
## FromOriginalModelMixin

[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin

## FromOriginalControlnetMixin

[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin
[[autodoc]] loaders.single_file_model.FromOriginalModelMixin
@@ -27,6 +27,7 @@ from .utils import (

_import_structure = {
    "configuration_utils": ["ConfigMixin"],
    "loaders": ["FromOriginalModelMixin"],
    "models": [],
    "pipelines": [],
    "schedulers": [],
@@ -340,6 +340,8 @@ class ConfigMixin:

        """
        cache_dir = kwargs.pop("cache_dir", None)
        local_dir = kwargs.pop("local_dir", None)
        local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)

@@ -364,13 +366,13 @@
        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
                # Load from a PyTorch checkpoint
                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            elif subfolder is not None and os.path.isfile(
            if subfolder is not None and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            ):
                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
                # Load from a PyTorch checkpoint
                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            else:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."

@@ -390,6 +392,8 @@
                user_agent=user_agent,
                subfolder=subfolder,
                revision=revision,
                local_dir=local_dir,
                local_dir_use_symlinks=local_dir_use_symlinks,
            )
        except RepositoryNotFoundError:
            raise EnvironmentError(
@@ -54,9 +54,7 @@ if is_transformers_available():
_import_structure = {}

if is_torch_available():
    _import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
    _import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
    _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
    _import_structure["utils"] = ["AttnProcsLayers"]
    if is_transformers_available():

@@ -70,8 +68,7 @@ _import_structure["peft"] = ["PeftAdapterMixin"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    if is_torch_available():
        from .autoencoder import FromOriginalVAEMixin
        from .controlnet import FromOriginalControlNetMixin
        from .single_file_model import FromOriginalModelMixin
        from .unet import UNet2DConditionLoadersMixin
        from .utils import AttnProcsLayers
@@ -11,144 +11,243 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
|
||||
from packaging import version
|
||||
|
||||
from ..utils import is_transformers_available, logging
|
||||
from ..utils import deprecate, is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
create_diffusers_unet_model_from_ldm,
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
create_scheduler_from_ldm,
|
||||
create_text_encoders_and_tokenizers_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
infer_model_type,
|
||||
SingleFileComponentError,
|
||||
_is_model_weights_in_cached_folder,
|
||||
_legacy_load_clip_tokenizer,
|
||||
_legacy_load_safety_checker,
|
||||
_legacy_load_scheduler,
|
||||
create_diffusers_clip_model_from_ldm,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
is_clip_model_in_single_file,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Pipelines that support the SDXL Refiner checkpoint
|
||||
REFINER_PIPELINES = [
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
]
|
||||
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
|
||||
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import AutoFeatureExtractor
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
def build_sub_model_components(
|
||||
pipeline_components,
|
||||
pipeline_class_name,
|
||||
component_name,
|
||||
original_config,
|
||||
def load_single_file_sub_model(
|
||||
library_name,
|
||||
class_name,
|
||||
name,
|
||||
checkpoint,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
cached_model_config_path,
|
||||
original_config=None,
|
||||
local_files_only=False,
|
||||
load_safety_checker=False,
|
||||
model_type=None,
|
||||
image_size=None,
|
||||
torch_dtype=None,
|
||||
is_legacy_loading=False,
|
||||
**kwargs,
|
||||
):
|
||||
if component_name in pipeline_components:
|
||||
return {}
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
|
||||
if component_name == "unet":
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
unet_components = create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
num_in_channels=num_in_channels,
|
||||
image_size=image_size,
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
|
||||
|
||||
if is_diffusers_single_file_model:
|
||||
load_method = getattr(class_obj, "from_single_file")
|
||||
|
||||
# We cannot provide two different config options to the `from_single_file` method
|
||||
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
|
||||
if original_config:
|
||||
cached_model_config_path = None
|
||||
|
||||
loaded_sub_model = load_method(
|
||||
pretrained_model_link_or_path_or_dict=checkpoint,
|
||||
original_config=original_config,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type=model_type,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
return unet_components
|
||||
|
||||
if component_name == "vae":
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
vae_components = create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
image_size,
|
||||
scaling_factor,
|
||||
torch_dtype,
|
||||
model_type=model_type,
|
||||
)
|
||||
return vae_components
|
||||
|
||||
if component_name == "scheduler":
|
||||
scheduler_type = kwargs.get("scheduler_type", "ddim")
|
||||
prediction_type = kwargs.get("prediction_type", None)
|
||||
|
||||
scheduler_components = create_scheduler_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
scheduler_type=scheduler_type,
|
||||
prediction_type=prediction_type,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
return scheduler_components
|
||||
|
||||
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
|
||||
text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
local_files_only=local_files_only,
|
||||
torch_dtype=torch_dtype,
|
||||
**kwargs,
|
||||
)
|
||||
return text_encoder_components
|
||||
|
||||
if component_name == "safety_checker":
|
||||
if load_safety_checker:
|
||||
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
|
||||
loaded_sub_model = create_diffusers_clip_model_from_ldm(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
)
|
||||
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
|
||||
elif is_tokenizer and is_legacy_loading:
|
||||
loaded_sub_model = _legacy_load_clip_tokenizer(
|
||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
elif is_diffusers_scheduler and is_legacy_loading:
|
||||
loaded_sub_model = _legacy_load_scheduler(
|
||||
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if not hasattr(class_obj, "from_pretrained"):
|
||||
raise ValueError(
|
||||
(
|
||||
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
|
||||
" a supported loading method."
|
||||
)
|
||||
)
|
||||
else:
|
||||
safety_checker = None
|
||||
return {"safety_checker": safety_checker}
|
||||
|
||||
if component_name == "feature_extractor":
|
||||
if load_safety_checker:
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
else:
|
||||
feature_extractor = None
|
||||
return {"feature_extractor": feature_extractor}
|
||||
|
||||
return
|
||||
|
||||
|
||||
def set_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint=None,
|
||||
model_type=None,
|
||||
):
|
||||
components = {}
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
loading_kwargs = {}
|
||||
loading_kwargs.update(
|
||||
{
|
||||
"requires_aesthetics_score": is_refiner,
|
||||
"force_zeros_for_empty_prompt": False if is_refiner else True,
|
||||
"pretrained_model_name_or_path": cached_model_config_path,
|
||||
"subfolder": name,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
)
|
||||
|
||||
return components
|
||||
# Schedulers and Tokenizers don't make use of torch_dtype
|
||||
# Skip passing it to those objects
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs.update({"torch_dtype": torch_dtype})
|
||||
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, "from_pretrained")
|
||||
loaded_sub_model = load_method(**loading_kwargs)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
def _map_component_types_to_config_dict(component_types):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
config_dict = {}
|
||||
component_types.pop("self", None)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
for component_name, component_value in component_types.items():
|
||||
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
|
||||
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
|
||||
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_transformers_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif is_scheduler_enum or is_scheduler:
|
||||
if is_scheduler_enum:
|
||||
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
|
||||
# if the type hint is a KarrasDiffusionSchedulers enum
|
||||
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
|
||||
|
||||
elif is_scheduler:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif (
|
||||
is_transformers_model or is_transformers_tokenizer
|
||||
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["transformers", component_value[0].__name__]
|
||||
|
||||
else:
|
||||
config_dict[component_name] = [None, None]
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _infer_pipeline_config_dict(pipeline_class):
|
||||
parameters = inspect.signature(pipeline_class.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
component_types = pipeline_class._get_signature_types()
|
||||
|
||||
# Ignore parameters that are not required for the pipeline
|
||||
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
|
||||
config_dict = _map_component_types_to_config_dict(component_types)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _download_diffusers_model_config_from_hub(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir,
|
||||
revision,
|
||||
proxies,
|
||||
force_download=None,
|
||||
resume_download=None,
|
||||
local_files_only=None,
|
||||
token=None,
|
||||
):
|
||||
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]
|
||||
cached_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
|
||||
return cached_model_path
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
@@ -195,27 +294,12 @@ class FromSingleFileMixin:
|
||||
original_config_file (`str`, *optional*):
|
||||
The path to the original config file that was used to train the model. If not provided, the config file
|
||||
will be inferred from the checkpoint file.
|
||||
model_type (`str`, *optional*):
|
||||
The type of model to load. If not provided, the model type will be inferred from the checkpoint file.
|
||||
image_size (`int`, *optional*):
|
||||
The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE
|
||||
model.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `False`):
|
||||
Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a
|
||||
`safety_checker` component is passed to the `kwargs`.
|
||||
num_in_channels (`int`, *optional*):
|
||||
Specify the number of input channels for the UNet model. Read more about how to configure UNet model
|
||||
with this parameter
|
||||
[here](https://huggingface.co/docs/diffusers/training/adapt_a_model#configure-unet2dconditionmodel-parameters).
|
||||
scaling_factor (`float`, *optional*):
|
||||
The scaling factor to use for the VAE model. If not provided, it is inferred from the config file
|
||||
first. If the scaling factor is not found in the config file, the default value 0.18215 is used.
|
||||
scheduler_type (`str`, *optional*):
|
||||
The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint
|
||||
file.
|
||||
prediction_type (`str`, *optional*):
|
||||
The type of prediction to load. If not provided, the prediction type will be inferred from the
|
||||
checkpoint file.
|
||||
config (`str`, *optional*):
|
||||
Can be either:
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
|
||||
component configs in Diffusers format.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
@@ -233,7 +317,7 @@ class FromSingleFileMixin:
|
||||
|
||||
>>> # Download pipeline from local file
|
||||
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
|
||||
|
||||
>>> # Enable float16 and move to GPU
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
@@ -242,9 +326,21 @@ class FromSingleFileMixin:
|
||||
... )
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if original_config_file is not None:
|
||||
deprecation_message = (
|
||||
"`original_config_file` argument is deprecated and will be removed in future versions."
|
||||
"please use the `original_config` argument instead."
|
||||
)
|
||||
deprecate("original_config_file", "1.0.0", deprecation_message)
|
||||
original_config = original_config_file
|
||||
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -253,68 +349,198 @@ class FromSingleFileMixin:
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
class_name = cls.__name__
|
||||
is_legacy_loading = False
|
||||
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
# We shouldn't allow configuring individual models components through a Pipeline creation method
|
||||
# These model kwargs should be deprecated
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
if scaling_factor is not None:
|
||||
deprecation_message = (
|
||||
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
|
||||
"and will be ignored in future versions."
|
||||
)
|
||||
deprecate("scaling_factor", "1.0.0", deprecation_message)
|
||||
|
||||
if original_config is not None:
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
|
||||
pipeline_class = _get_pipeline_class(cls, config=None)
|
||||
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
if config is None:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
else:
|
||||
default_pretrained_model_config_name = config
|
||||
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config=None,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
if not os.path.isdir(default_pretrained_model_config_name):
|
||||
# Provided config is a repo_id
|
||||
if default_pretrained_model_config_name.count("/") > 1:
|
||||
raise ValueError(
|
||||
f'The provided config "{config}"'
|
||||
" is neither a valid local path nor a valid repo id. Please check the parameter."
|
||||
)
|
||||
try:
|
||||
# Attempt to download the config files for the pipeline
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
except LocalEntryNotFoundError:
|
||||
# `local_files_only=True` but a local diffusers format model config is not available in the cache
|
||||
# If `original_config` is not provided, we need override `local_files_only` to False
|
||||
# to fetch the config files from the hub so that we have a way
|
||||
# to configure the pipeline components.
|
||||
|
||||
if original_config is None:
|
||||
logger.warning(
|
||||
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
|
||||
"Attempting to download the necessary config files for this pipeline.\n"
|
||||
)
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
local_files_only=False,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
else:
|
||||
# For backwards compatibility
|
||||
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
|
||||
logger.warning(
|
||||
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
|
||||
"This may lead to errors if the model components are not correctly inferred. \n"
|
||||
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
|
||||
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
|
||||
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
|
||||
"the necessary config files.\n"
|
||||
)
|
||||
is_legacy_loading = True
|
||||
cached_model_config_path = None
|
||||
|
||||
config_dict = _infer_pipeline_config_dict(pipeline_class)
|
||||
config_dict["_class_name"] = pipeline_class.__name__
|
||||
|
||||
else:
|
||||
# Provided config is a path to a local directory attempt to load directly.
|
||||
cached_model_config_path = default_pretrained_model_config_name
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
# pop out "_ignore_files" as it is only needed for download
|
||||
config_dict.pop("_ignore_files", None)
|
||||
|
||||
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
model_type = kwargs.pop("model_type", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
load_safety_checker = (kwargs.pop("load_safety_checker", False)) or (
|
||||
passed_class_obj.get("safety_checker", None) is not None
|
||||
)
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
from diffusers import pipelines
|
||||
|
||||
# remove `null` components
|
||||
def load_module(name, value):
|
||||
if value[0] is None:
|
||||
return False
|
||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
||||
return False
|
||||
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
for name, (library_name, class_name) in logging.tqdm(
|
||||
sorted(init_dict.items()), desc="Loading pipeline components..."
|
||||
):
|
||||
loaded_sub_model = None
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
|
||||
init_kwargs = {}
|
||||
for name in expected_modules:
|
||||
if name in passed_class_obj:
|
||||
init_kwargs[name] = passed_class_obj[name]
|
||||
loaded_sub_model = passed_class_obj[name]
|
||||
|
||||
else:
|
||||
components = build_sub_model_components(
|
||||
init_kwargs,
|
||||
class_name,
|
||||
name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
image_size=image_size,
|
||||
load_safety_checker=load_safety_checker,
|
||||
local_files_only=local_files_only,
|
||||
torch_dtype=torch_dtype,
|
||||
**kwargs,
|
||||
)
|
||||
if not components:
|
||||
continue
|
||||
init_kwargs.update(components)
|
||||
try:
|
||||
loaded_sub_model = load_single_file_sub_model(
|
||||
library_name=library_name,
|
||||
class_name=class_name,
|
||||
name=name,
|
||||
checkpoint=checkpoint,
|
||||
is_pipeline_module=is_pipeline_module,
|
||||
cached_model_config_path=cached_model_config_path,
|
||||
pipelines=pipelines,
|
||||
torch_dtype=torch_dtype,
|
||||
original_config=original_config,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
**kwargs,
|
||||
)
|
||||
except SingleFileComponentError as e:
|
||||
raise SingleFileComponentError(
|
||||
(
|
||||
f"{e.message}\n"
|
||||
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
|
||||
f"\n"
|
||||
f"{name} = {class_name}.from_pretrained('...')\n"
|
||||
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
|
||||
f"\n"
|
||||
)
|
||||
)
|
||||
|
||||
additional_components = set_additional_components(
|
||||
class_name, original_config, checkpoint=checkpoint, model_type=model_type
|
||||
)
|
||||
if additional_components:
|
||||
init_kwargs.update(additional_components)
|
||||
init_kwargs[name] = loaded_sub_model
|
||||
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
||||
for module in missing_modules:
|
||||
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||
elif len(missing_modules) > 0:
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||
raise ValueError(
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# deprecated kwargs
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", None)
|
||||
if load_safety_checker is not None:
|
||||
deprecation_message = (
|
||||
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
|
||||
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
|
||||
)
|
||||
deprecate("load_safety_checker", "1.0.0", deprecation_message)
|
||||
|
||||
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
|
||||
init_kwargs.update(safety_checker_components)
|
||||
|
||||
init_kwargs.update(passed_pipe_kwargs)
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
if torch_dtype is not None:
|
||||
|
||||
290  src/diffusers/loaders/single_file_model.py  (new file)
@@ -0,0 +1,290 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
create_vae_diffusers_config_from_ldm,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
|
||||
SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"StableCascadeUNet": {
|
||||
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
},
|
||||
"UNet2DConditionModel": {
|
||||
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
|
||||
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
|
||||
"default_subfolder": "unet",
|
||||
"legacy_kwargs": {
|
||||
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
|
||||
},
|
||||
},
|
||||
"AutoencoderKL": {
|
||||
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
|
||||
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"ControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
|
||||
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
||||
parameters = inspect.signature(mapping_fn).parameters
|
||||
|
||||
mapping_kwargs = {}
|
||||
for parameter in parameters:
|
||||
if parameter in kwargs:
|
||||
mapping_kwargs[parameter] = kwargs[parameter]
|
||||
|
||||
return mapping_kwargs
|
||||
|
||||
|
||||
class FromOriginalModelMixin:
|
||||
"""
|
||||
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
|
||||
r"""
|
||||
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
|
||||
is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path_or_dict (`str`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.safetensors` or `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
|
||||
- A path to a local *file* containing the weights of the component model.
|
||||
- A state dict containing the component model weights.
|
||||
config (`str`, *optional*):
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
|
||||
on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
|
||||
configs in Diffusers format.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
original_config (`str`, *optional*):
|
||||
Dict or path to a yaml file containing the configuration for the model in its original format.
|
||||
If a dict is provided, it will be used to initialize the model configuration.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableCascadeUNet
|
||||
|
||||
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
|
||||
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
|
||||
```
|
||||
"""
|
||||
|
||||
class_name = cls.__name__
|
||||
if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
raise ValueError(
|
||||
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
|
||||
)
|
||||
|
||||
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
|
||||
if pretrained_model_link_or_path is not None:
|
||||
deprecation_message = (
|
||||
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
|
||||
)
|
||||
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
|
||||
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if config is not None and original_config is not None:
|
||||
raise ValueError(
|
||||
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
|
||||
)
|
||||
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
else:
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path_or_dict,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
|
||||
|
||||
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
||||
if original_config:
|
||||
if "config_mapping_fn" in mapping_functions:
|
||||
config_mapping_fn = mapping_functions["config_mapping_fn"]
|
||||
else:
|
||||
config_mapping_fn = None
|
||||
|
||||
if config_mapping_fn is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"`original_config` has been provided for {class_name} but no mapping function"
|
||||
"was found to convert the original config to a Diffusers config in"
|
||||
"`diffusers.loaders.single_file_utils`"
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(original_config, str):
|
||||
# If original_config is a URL or filepath fetch the original_config dict
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
|
||||
diffusers_model_config = config_mapping_fn(
|
||||
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
if config:
|
||||
if isinstance(config, str):
|
||||
default_pretrained_model_config_name = config
|
||||
else:
|
||||
raise ValueError(
|
||||
(
|
||||
"Invalid `config` argument. Please provide a string representing a repo id"
|
||||
"or path to a local Diffusers model repo."
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
|
||||
if "default_subfolder" in mapping_functions:
|
||||
subfolder = mapping_functions["default_subfolder"]
|
||||
|
||||
subfolder = subfolder or config.pop(
|
||||
"subfolder", None
|
||||
) # some configs contain a subfolder key, e.g. StableCascadeUNet
|
||||
|
||||
diffusers_model_config = cls.load_config(
|
||||
pretrained_model_name_or_path=default_pretrained_model_config_name,
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
# Map legacy kwargs to new kwargs
|
||||
if "legacy_kwargs" in mapping_functions:
|
||||
legacy_kwargs = mapping_functions["legacy_kwargs"]
|
||||
for legacy_key, new_key in legacy_kwargs.items():
|
||||
if legacy_key in kwargs:
|
||||
kwargs[new_key] = kwargs.pop(legacy_key)
|
||||
|
||||
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
||||
diffusers_model_config.update(model_kwargs)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
if model._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in model._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
model.eval()
|
||||
|
||||
return model
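For context, this generic loader means any model class registered in `SINGLE_FILE_LOADABLE_CLASSES` gains the same entry point, with the Diffusers config either inferred from the checkpoint or supplied explicitly. A minimal usage sketch follows; the `config` repo id and `subfolder` value are assumptions about the Stable Cascade repo layout, not taken from this diff:

```python
import torch
from diffusers import StableCascadeUNet

ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"

# Default path: the Diffusers config repo is inferred from the checkpoint contents.
unet = StableCascadeUNet.from_single_file(ckpt_path, torch_dtype=torch.float16)

# Optionally point `config` at a Diffusers model repo instead. `config` and
# `original_config` are mutually exclusive, as enforced above.
unet = StableCascadeUNet.from_single_file(
    ckpt_path,
    config="stabilityai/stable-cascade",  # assumed repo id, for illustration only
    subfolder="decoder",                  # assumed subfolder, for illustration only
    torch_dtype=torch.float16,
)
```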
File diff suppressed because it is too large
@@ -44,11 +44,6 @@ from ..utils import (
    set_adapter_layers,
    set_weights_and_activate_adapters,
)
from .single_file_utils import (
    convert_stable_cascade_unet_single_file_to_diffusers,
    infer_stable_cascade_single_file_config,
    load_single_file_model_checkpoint,
)
from .unet_loader_utils import _maybe_expand_lora_scales
from .utils import AttnProcsLayers

@@ -1059,103 +1054,3 @@ class UNet2DConditionLoadersMixin:
                }
            )
        return lora_dicts


class FromOriginalUNetMixin:
    """
    Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
    """

    @classmethod
    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
        `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:
                    - A link to the `.ckpt` file (for example
                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
                    - A path to a *file* containing all pipeline weights.
            config: (`dict`, *optional*):
                Dictionary containing the configuration of the model:
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
                of Diffusers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to True, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load and saveable variables of the model.

        """
        class_name = cls.__name__
        if class_name != "StableCascadeUNet":
            raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")

        config = kwargs.pop("config", None)
        resume_download = kwargs.pop("resume_download", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        checkpoint = load_single_file_model_checkpoint(
            pretrained_model_link_or_path,
            resume_download=resume_download,
            force_download=force_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
        )

        if config is None:
            config = infer_stable_cascade_single_file_config(checkpoint)
            model_config = cls.load_config(**config, **kwargs)
        else:
            model_config = config

        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            model = cls.from_config(model_config, **kwargs)

        diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
        if is_accelerate_available():
            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
            if len(unexpected_keys) > 0:
                logger.warning(
                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                )

        else:
            model.load_state_dict(diffusers_format_checkpoint)

        if torch_dtype is not None:
            model.to(torch_dtype)

        return model
@@ -17,7 +17,7 @@ import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,

@@ -32,7 +32,7 @@ from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
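`AutoencoderKL` now picks up the same generic mixin in place of `FromOriginalVAEMixin`, so a single-file VAE checkpoint loads through the shared path; the integration tests further down in this diff exercise exactly this. A short sketch, using the checkpoint URL from those tests:

```python
import torch
from diffusers import AutoencoderKL

url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
vae = AutoencoderKL.from_single_file(url, torch_dtype=torch.float16)
```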
@@ -19,7 +19,7 @@ from torch import nn
from torch.nn import functional as F

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlNetMixin
from ..loaders.single_file_model import FromOriginalModelMixin
from ..utils import BaseOutput, logging
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,

@@ -108,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module):
        return embedding


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """
    A ControlNet model.
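Likewise, `ControlNetModel` swaps its dedicated mixin for the shared one. The ControlNet pipeline tests below load the model this way and hand it to a single-file pipeline; a condensed sketch of that pattern, with the URLs used in those tests:

```python
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_single_file(
    "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe = StableDiffusionControlNetPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
    controlnet=controlnet,
    safety_checker=None,
)
```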
@@ -963,6 +963,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

    @classmethod
    def _get_signature_keys(cls, obj):
        parameters = inspect.signature(obj.__init__).parameters
        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
        expected_modules = set(required_parameters.keys()) - {"self"}

        return expected_modules, optional_parameters

    # Adapted from `transformers` modeling_utils.py
    def _get_no_split_modules(self, device_map: str):
        """
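`_get_signature_keys` is the piece that lets `from_single_file` keep only the user kwargs that the model's `__init__` actually accepts. A small self-contained sketch of the same introspection on a made-up class (it uses the public `inspect.Parameter.empty` rather than `inspect._empty`):

```python
import inspect


class ToyModel:
    def __init__(self, in_channels, out_channels, act_fn="silu"):
        pass


parameters = inspect.signature(ToyModel.__init__).parameters
# Parameters without a default are required; the rest are optional.
required = {k: v for k, v in parameters.items() if v.default is inspect.Parameter.empty}
optional = {k for k, v in parameters.items() if v.default is not inspect.Parameter.empty}
expected_modules = set(required.keys()) - {"self"}

print(expected_modules)  # {'in_channels', 'out_channels'}
print(optional)          # {'act_fn'}
```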
@@ -20,6 +20,7 @@ import torch.utils.checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (

@@ -66,7 +67,9 @@ class UNet2DConditionOutput(BaseOutput):
    sample: torch.FloatTensor = None


class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
class UNet2DConditionModel(
    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.
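Because `UNet2DConditionModel` now inherits `FromOriginalModelMixin` as well, the UNet can be pulled out of a full single-file checkpoint on its own, with the matching Diffusers config inferred by the loader. A minimal sketch, using the checkpoint URL that appears in the tests elsewhere in this diff:

```python
from diffusers import UNet2DConditionModel

ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
unet = UNet2DConditionModel.from_single_file(ckpt_path)
```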
@@ -21,7 +21,7 @@ import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.unet import FromOriginalUNetMixin
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput
from ..attention_processor import Attention
from ..modeling_utils import ModelMixin

@@ -134,7 +134,7 @@ class StableCascadeUNetOutput(BaseOutput):
    sample: torch.FloatTensor = None


class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
@@ -609,6 +609,7 @@ def load_sub_model(
):
    """Helper method to load the module `name` from `library_name` and `class_name`"""
    # retrieve class candidates

    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
@@ -791,62 +791,6 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
tolerance = 3e-3 if torch_device != "mps" else 1e-2
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
|
||||
|
||||
def test_stable_diffusion_model_local(self):
|
||||
model_id = "stabilityai/sd-vae-ft-mse"
|
||||
model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
|
||||
model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
|
||||
image = self.get_sd_image(33)
|
||||
|
||||
with torch.no_grad():
|
||||
sample_1 = model_1(image).sample
|
||||
sample_2 = model_2(image).sample
|
||||
|
||||
assert sample_1.shape == sample_2.shape
|
||||
|
||||
output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
|
||||
assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
vae_single_file = AutoencoderKL.from_single_file(
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
|
||||
for param_name, param_value in vae_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_single_file_arguments(self):
|
||||
vae_default = AutoencoderKL.from_single_file(
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
|
||||
)
|
||||
|
||||
assert vae_default.config.scaling_factor == 0.18215
|
||||
assert vae_default.config.sample_size == 512
|
||||
assert vae_default.dtype == torch.float32
|
||||
|
||||
scaling_factor = 2.0
|
||||
image_size = 256
|
||||
torch_dtype = torch.float16
|
||||
|
||||
vae = AutoencoderKL.from_single_file(
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
|
||||
image_size=image_size,
|
||||
scaling_factor=scaling_factor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
assert vae.config.scaling_factor == scaling_factor
|
||||
assert vae.config.sample_size == image_size
|
||||
assert vae.dtype == torch_dtype
|
||||
|
||||
|
||||
@slow
|
||||
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
@@ -56,7 +56,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in single_file_unet_config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
@@ -78,7 +78,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in single_file_unet_config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
@@ -97,7 +97,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
|
||||
@@ -38,7 +38,6 @@ from diffusers.utils.testing_utils import (
|
||||
get_python_version,
|
||||
load_image,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
@@ -1063,97 +1062,6 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_sf = StableDiffusionControlNetPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
pipe_sf.unet.set_default_attn_processor()
|
||||
pipe_sf.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
prompt = "bird"
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=control_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output_sf = pipe_sf(
|
||||
prompt,
|
||||
image=control_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", variant="fp16", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
|
||||
controlnet_single_file = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
single_file_pipe = StableDiffusionControlNetPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet_single_file,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.controlnet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
|
||||
# This parameter doesn't appear to be loaded from the config.
|
||||
# So when it is registered to config, it remains a tuple as this is the default in the class definition
|
||||
# from_pretrained, does load from config and converts to a list when registering to config
|
||||
if param_name == "conditioning_embedding_out_channels" and isinstance(param_value, tuple):
|
||||
param_value = list(param_value)
|
||||
|
||||
assert (
|
||||
pipe.controlnet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -39,7 +39,6 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -441,56 +440,3 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
assert np.abs(expected_image - image).max() < 9e-2
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_sf = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
pipe_sf.unet.set_default_attn_processor()
|
||||
pipe_sf.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
prompt = "bird"
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output_sf = pipe_sf(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@@ -556,55 +556,3 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
assert numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten()) < 1e-2
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetInpaintPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
).resize((512, 512))
|
||||
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
mask_image=mask_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
)
|
||||
images.append(output.images[0])
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images[1].flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@@ -37,7 +37,6 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -949,89 +948,6 @@ class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
|
||||
expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
|
||||
assert np.allclose(original_image, expected_image, atol=1e-04)
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16)
|
||||
single_file_url = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
pipe_single_file = StableDiffusionXLControlNetPipeline.from_single_file(
|
||||
single_file_url, controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe_single_file.unet.set_default_attn_processor()
|
||||
pipe_single_file.enable_model_cpu_offload()
|
||||
pipe_single_file.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "Stormtrooper's lecture"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
|
||||
)
|
||||
single_file_images = pipe_single_file(
|
||||
prompt, image=image, generator=generator, output_type="np", num_inference_steps=2
|
||||
).images
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=2).images
|
||||
|
||||
assert images[0].shape == (512, 512, 3)
|
||||
assert single_file_images[0].shape == (512, 512, 3)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), single_file_images[0].flatten())
|
||||
assert max_diff < 5e-2
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
variant="fp16",
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
single_file_url = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
single_file_pipe = StableDiffusionXLControlNetPipeline.from_single_file(
|
||||
single_file_url, controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
|
||||
# Upcast attention might be set to None in a config file, which is incorrect. It should default to False in the model
|
||||
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
|
||||
pipe.unet.config[param_name] = False
|
||||
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
|
||||
class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests):
|
||||
def test_controlnet_sdxl_guess(self):
|
||||
|
||||
@@ -42,7 +42,6 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
logging,
|
||||
)
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
enable_full_determinism,
|
||||
@@ -1284,62 +1283,6 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
|
||||
|
||||
assert image_out.shape == (512, 512, 3)
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
|
||||
|
||||
sf_pipe = StableDiffusionPipeline.from_single_file(ckpt_path)
|
||||
sf_pipe.scheduler = DDIMScheduler.from_config(sf_pipe.scheduler.config)
|
||||
sf_pipe.unet.set_attn_processor(AttnProcessor())
|
||||
sf_pipe.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_single_file = sf_pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_attn_processor(AttnProcessor())
|
||||
pipe.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
|
||||
single_file_pipe = StableDiffusionPipeline.from_single_file(ckpt_path, load_safety_checker=True)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.safety_checker.config.to_dict()[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -36,7 +36,6 @@ from diffusers import (
|
||||
StableDiffusionInpaintPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
@@ -44,7 +43,6 @@ from diffusers.utils.testing_utils import (
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
@@ -786,77 +784,6 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||
expected_slice = np.array([0.3757, 0.3875, 0.4445, 0.4353, 0.3780, 0.4513, 0.3965, 0.3984, 0.4362])
|
||||
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
||||
|
||||
def test_download_local(self):
|
||||
filename = hf_hub_download("runwayml/stable-diffusion-inpainting", filename="sd-v1-5-inpainting.ckpt")
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to("cuda")
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 1
|
||||
image_out = pipe(**inputs).images[0]
|
||||
|
||||
assert image_out.shape == (512, 512, 3)
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_attn_processor(AttnProcessor())
|
||||
pipe.to("cuda")
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 5
|
||||
image_ckpt = pipe(**inputs).images[0]
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_attn_processor(AttnProcessor())
|
||||
pipe.to("cuda")
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 5
|
||||
image = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
|
||||
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", variant="fp16")
|
||||
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
|
||||
single_file_pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path, load_safety_checker=True)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.safety_checker.config.to_dict()[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@@ -1082,9 +1009,6 @@ class StableDiffusionInpaintPipelineAsymmetricAutoencoderKLSlowTests(unittest.Te
|
||||
|
||||
assert image_out.shape == (512, 512, 3)
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
pass
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -29,7 +29,6 @@ from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -492,73 +491,3 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 2.9 GB is allocated
|
||||
assert mem_bytes < 2.9 * 10**9
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/sd2-upscale/low_res_cat.png"
|
||||
)
|
||||
|
||||
prompt = "a cat sitting on a park bench"
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
|
||||
image_from_pretrained = output.images[0]
|
||||
|
||||
single_file_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
|
||||
)
|
||||
pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(single_file_path)
|
||||
pipe_from_single_file.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
output_from_single_file = pipe_from_single_file(
|
||||
prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
|
||||
)
|
||||
image_from_single_file = output_from_single_file.images[0]
|
||||
|
||||
assert image_from_pretrained.shape == (512, 512, 3)
|
||||
assert image_from_single_file.shape == (512, 512, 3)
|
||||
assert (
|
||||
numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
|
||||
)
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-x4-upscaler", variant="fp16"
|
||||
)
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
|
||||
)
|
||||
single_file_pipe = StableDiffusionUpscalePipeline.from_single_file(ckpt_path, load_safety_checker=True)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.safety_checker.config.to_dict()[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
@@ -30,7 +30,6 @@ from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_numpy,
|
||||
@@ -473,30 +472,6 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert image_out.shape == (768, 768, 3)
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
single_file_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
|
||||
)
|
||||
|
||||
pipe_single = StableDiffusionPipeline.from_single_file(single_file_path)
|
||||
pipe_single.scheduler = DDIMScheduler.from_config(pipe_single.scheduler.config)
|
||||
pipe_single.unet.set_attn_processor(AttnProcessor())
|
||||
pipe_single.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_ckpt = pipe_single("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_attn_processor(AttnProcessor())
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
|
||||
number_of_steps = 0
|
||||
|
||||
|
||||
@@ -1046,68 +1046,3 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
|
||||
|
||||
assert max_diff < 1e-2
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_ckpt = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
|
||||
|
||||
assert max_diff < 6e-3
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
)
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
single_file_pipe = StableDiffusionXLPipeline.from_single_file(
|
||||
ckpt_path, variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
|
||||
pipe.unet.config[param_name] = False
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
@@ -32,13 +31,10 @@ from diffusers import (
|
||||
T2IAdapter,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import load_image, logging
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -678,54 +674,3 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
|
||||
print(",".join(debug))
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class AdapterSDXLPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
|
||||
prompt = "toy"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
|
||||
)
|
||||
pipe_single_file = StableDiffusionXLAdapterPipeline.from_single_file(
|
||||
ckpt_path,
|
||||
adapter=adapter,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe_single_file.enable_model_cpu_offload()
|
||||
pipe_single_file.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
images_single_file = pipe_single_file(
|
||||
prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
|
||||
).images
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
adapter=adapter,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
|
||||
|
||||
assert images_single_file[0].shape == (768, 512, 3)
|
||||
assert images[0].shape == (768, 512, 3)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images_single_file[0].flatten())
|
||||
assert max_diff < 5e-3
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
@@ -32,19 +31,15 @@ from transformers import (
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
AutoencoderTiny,
|
||||
DDIMScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LCMScheduler,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -781,85 +776,3 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
self._test_save_load_optional_components()
|
||||
|
||||
|
||||
@slow
|
||||
class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_img2img/sketch-mountains-input.png"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = pipe(
|
||||
prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
|
||||
).images[0]
|
||||
|
||||
pipe_single_file = StableDiffusionXLImg2ImgPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
|
||||
pipe_single_file.scheduler = DDIMScheduler.from_config(pipe_single_file.scheduler.config)
|
||||
pipe_single_file.unet.set_default_attn_processor()
|
||||
pipe_single_file.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_single_file = pipe_single_file(
|
||||
prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
|
||||
).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
|
||||
|
||||
assert max_diff < 5e-2
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
|
||||
single_file_pipe = StableDiffusionXLImg2ImgPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
|
||||
|
||||
assert pipe.text_encoder is None
|
||||
assert single_file_pipe.text_encoder is None
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
|
||||
pipe.unet.config[param_name] = False
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
0    tests/single_file/__init__.py    Normal file
380  tests/single_file/single_file_testing_utils.py    Normal file
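The new `tests/single_file` module centralizes the single-file checks that were previously duplicated across the pipeline test files. The `SDSingleFileTesterMixin` defined in the file below is meant to be combined with a concrete pipeline test case roughly like this (the class, repo ids, and config URL here are illustrative, not taken from this diff):

```python
import unittest

from diffusers import StableDiffusionPipeline

from .single_file_testing_utils import SDSingleFileTesterMixin


class StableDiffusionPipelineSingleFileTests(SDSingleFileTesterMixin, unittest.TestCase):
    # Attributes the mixin reads; values shown are examples only.
    pipeline_class = StableDiffusionPipeline
    ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
    repo_id = "runwayml/stable-diffusion-v1-5"
    original_config = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
```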
@@ -0,0 +1,380 @@
|
||||
import tempfile
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def download_single_file_checkpoint(repo_id, filename, tmpdir):
|
||||
path = hf_hub_download(repo_id, filename=filename, local_dir=tmpdir)
|
||||
return path
|
||||
|
||||
|
||||
def download_original_config(config_url, tmpdir):
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
path = f"{tmpdir}/config.yaml"
|
||||
with open(path, "wb") as f:
|
||||
f.write(original_config_file.read())
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def download_diffusers_config(repo_id, tmpdir):
|
||||
path = snapshot_download(
|
||||
repo_id,
|
||||
ignore_patterns=[
|
||||
"**/*.ckpt",
|
||||
"*.ckpt",
|
||||
"**/*.bin",
|
||||
"*.bin",
|
||||
"**/*.pt",
|
||||
"*.pt",
|
||||
"**/*.safetensors",
|
||||
"*.safetensors",
|
||||
],
|
||||
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
|
||||
local_dir=tmpdir,
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
class SDSingleFileTesterMixin:
|
||||
def _compare_component_configs(self, pipe, single_file_pipe):
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = [
|
||||
"torch_dtype",
|
||||
"_name_or_path",
|
||||
"architectures",
|
||||
"_use_default_values",
|
||||
"_diffusers_version",
|
||||
]
|
||||
for component_name, component in single_file_pipe.components.items():
|
||||
if component_name in single_file_pipe._optional_components:
|
||||
continue
|
||||
|
||||
# skip testing transformer based components here
|
||||
# skip text encoders / safety checkers since they have already been tested
|
||||
if component_name in ["text_encoder", "tokenizer", "safety_checker", "feature_extractor"]:
|
||||
continue
|
||||
|
||||
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
|
||||
assert isinstance(
|
||||
component, pipe.components[component_name].__class__
|
||||
), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
|
||||
|
||||
for param_name, param_value in component.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
|
||||
# Some pretrained configs will set upcast attention to None
|
||||
# In single file loading it defaults to the value in the class __init__ which is False
|
||||
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
|
||||
pipe.components[component_name].config[param_name] = param_value
|
||||
|
||||
assert (
|
||||
pipe.components[component_name].config[param_name] == param_value
|
||||
), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
|
||||
|
||||
def test_single_file_components(self, pipe=None, single_file_pipe=None):
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, safety_checker=None
|
||||
)
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
def test_single_file_components_local_files_only(self, pipe=None, single_file_pipe=None):
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
local_ckpt_path, safety_checker=None, local_files_only=True
|
||||
)
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
def test_single_file_components_with_original_config(
|
||||
self,
|
||||
pipe=None,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
# Not possible to infer this value when original config is provided
|
||||
# we just pass it in here otherwise this test will fail
|
||||
upcast_attention = pipe.unet.config.upcast_attention
|
||||
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
self.ckpt_path,
|
||||
original_config=self.original_config,
|
||||
safety_checker=None,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
def test_single_file_components_with_original_config_local_files_only(
|
||||
self,
|
||||
pipe=None,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
|
||||
# Not possible to infer this value when original config is provided
|
||||
# we just pass it in here otherwise this test will fail
|
||||
upcast_attention = pipe.unet.config.upcast_attention
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_original_config = download_original_config(self.original_config, tmpdir)
|
||||
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
local_ckpt_path,
|
||||
original_config=local_original_config,
|
||||
safety_checker=None,
|
||||
upcast_attention=upcast_attention,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
|
||||
sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None)
|
||||
sf_pipe.unet.set_attn_processor(AttnProcessor())
|
||||
sf_pipe.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image_single_file = sf_pipe(**inputs).images[0]
|
||||
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
pipe.unet.set_attn_processor(AttnProcessor())
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
|
||||
|
||||
assert max_diff < expected_max_diff
|
||||
|
||||
def test_single_file_components_with_diffusers_config(
|
||||
self,
|
||||
pipe=None,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, config=self.repo_id, safety_checker=None
|
||||
)
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
def test_single_file_components_with_diffusers_config_local_files_only(
|
||||
self,
|
||||
pipe=None,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
|
||||
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
|
||||
)
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
|
||||
class SDXLSingleFileTesterMixin:
|
||||
def _compare_component_configs(self, pipe, single_file_pipe):
|
||||
# Skip testing the text_encoder for Refiner Pipelines
|
||||
if pipe.text_encoder:
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = [
|
||||
"torch_dtype",
|
||||
"_name_or_path",
|
||||
"architectures",
|
||||
"_use_default_values",
|
||||
"_diffusers_version",
|
||||
]
|
||||
for component_name, component in single_file_pipe.components.items():
|
||||
if component_name in single_file_pipe._optional_components:
|
||||
continue
|
||||
|
||||
# skip text encoders since they have already been tested
|
||||
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
|
||||
continue
|
||||
|
||||
# skip safety checker if it is not present in the pipeline
|
||||
if component_name in ["safety_checker", "feature_extractor"]:
|
||||
continue
|
||||
|
||||
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
|
||||
assert isinstance(
|
||||
component, pipe.components[component_name].__class__
|
||||
), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
|
||||
|
||||
for param_name, param_value in component.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
|
||||
# Some pretrained configs will set upcast attention to None
|
||||
# In single file loading it defaults to the value in the class __init__ which is False
|
||||
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
|
||||
pipe.components[component_name].config[param_name] = param_value
|
||||
|
||||
assert (
|
||||
pipe.components[component_name].config[param_name] == param_value
|
||||
), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
|
||||
|
||||
def test_single_file_components(self, pipe=None, single_file_pipe=None):
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, safety_checker=None
|
||||
)
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
|
||||
self._compare_component_configs(
|
||||
pipe,
|
||||
single_file_pipe,
|
||||
)
|
||||
|
||||
def test_single_file_components_local_files_only(
|
||||
self,
|
||||
pipe=None,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
local_ckpt_path, safety_checker=None, local_files_only=True
|
||||
)
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
def test_single_file_components_with_original_config(
|
||||
self,
|
||||
pipe=None,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
# Not possible to infer this value when original config is provided
|
||||
# we just pass it in here otherwise this test will fail
|
||||
upcast_attention = pipe.unet.config.upcast_attention
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
self.ckpt_path,
|
||||
original_config=self.original_config,
|
||||
safety_checker=None,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
self._compare_component_configs(
|
||||
pipe,
|
||||
single_file_pipe,
|
||||
)
|
||||
|
||||
def test_single_file_components_with_original_config_local_files_only(
|
||||
self,
|
||||
pipe=None,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
# Not possible to infer this value when original config is provided
|
||||
# we just pass it in here otherwise this test will fail
|
||||
upcast_attention = pipe.unet.config.upcast_attention
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_original_config = download_original_config(self.original_config, tmpdir)
|
||||
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
local_ckpt_path,
|
||||
original_config=local_original_config,
|
||||
upcast_attention=upcast_attention,
|
||||
safety_checker=None,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
self._compare_component_configs(
|
||||
pipe,
|
||||
single_file_pipe,
|
||||
)
|
||||
|
||||
def test_single_file_components_with_diffusers_config(
|
||||
self,
|
||||
pipe=None,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, config=self.repo_id, safety_checker=None
|
||||
)
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
def test_single_file_components_with_diffusers_config_local_files_only(
|
||||
self,
|
||||
pipe=None,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
|
||||
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
|
||||
)
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
|
||||
sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16, safety_checker=None)
|
||||
sf_pipe.unet.set_default_attn_processor()
|
||||
sf_pipe.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image_single_file = sf_pipe(**inputs).images[0]
|
||||
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16, safety_checker=None)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
|
||||
|
||||
assert max_diff < expected_max_diff
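The mixin and helpers above are consumed by the test modules that follow. As a minimal, illustrative sketch (not part of the diff itself): the `config` argument of `from_single_file` can be either a Hub repo id or a local diffusers-format snapshot such as the one `download_diffusers_config` produces, which is what makes the `local_files_only=True` variants below possible. The checkpoint URL and repo id are the ones used by those tests; the relative import mirrors how the test modules import the helpers.

import tempfile

from diffusers import StableDiffusionPipeline

from .single_file_testing_utils import download_diffusers_config, download_single_file_checkpoint

ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
repo_id = "runwayml/stable-diffusion-v1-5"

# 1. Resolve the diffusers-format config directly from the Hub repo.
pipe = StableDiffusionPipeline.from_single_file(ckpt_path, config=repo_id, safety_checker=None)

# 2. Resolve everything from local files: download the checkpoint and a JSON-only
#    snapshot of the repo first, then load with local_files_only=True.
with tempfile.TemporaryDirectory() as tmpdir:
    local_ckpt_path = download_single_file_checkpoint(repo_id, "v1-5-pruned-emaonly.safetensors", tmpdir)
    local_config = download_diffusers_config(repo_id, tmpdir)
    pipe = StableDiffusionPipeline.from_single_file(
        local_ckpt_path, config=local_config, safety_checker=None, local_files_only=True
    )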
78
tests/single_file/test_model_controlnet_single_file.py
Normal file
@@ -0,0 +1,78 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers import (
    ControlNetModel,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    require_torch_gpu,
    slow,
)


enable_full_determinism()


@slow
@require_torch_gpu
class ControlNetModelSingleFileTests(unittest.TestCase):
    model_class = ControlNetModel
    ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
    repo_id = "lllyasviel/control_v11p_sd15_canny"

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id)
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"

    def test_single_file_arguments(self):
        model_default = self.model_class.from_single_file(self.ckpt_path)

        assert model_default.config.upcast_attention is False
        assert model_default.dtype == torch.float32

        torch_dtype = torch.float16
        upcast_attention = True

        model = self.model_class.from_single_file(
            self.ckpt_path,
            upcast_attention=upcast_attention,
            torch_dtype=torch_dtype,
        )
        assert model.config.upcast_attention == upcast_attention
        assert model.dtype == torch_dtype
114
tests/single_file/test_model_sd_cascade_unet_single_file.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import StableCascadeUNet
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableCascadeUNetSingleFileTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_single_file_components_stage_b(self):
|
||||
model_single_file = StableCascadeUNet.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
model = StableCascadeUNet.from_pretrained(
|
||||
"stabilityai/stable-cascade", variant="bf16", subfolder="decoder", use_safetensors=True
|
||||
)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_single_file_components_stage_b_lite(self):
|
||||
model_single_file = StableCascadeUNet.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
model = StableCascadeUNet.from_pretrained(
|
||||
"stabilityai/stable-cascade", variant="bf16", subfolder="decoder_lite"
|
||||
)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_single_file_components_stage_c(self):
|
||||
model_single_file = StableCascadeUNet.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
model = StableCascadeUNet.from_pretrained(
|
||||
"stabilityai/stable-cascade-prior", variant="bf16", subfolder="prior"
|
||||
)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_single_file_components_stage_c_lite(self):
|
||||
model_single_file = StableCascadeUNet.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_lite_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
model = StableCascadeUNet.from_pretrained(
|
||||
"stabilityai/stable-cascade-prior", variant="bf16", subfolder="prior_lite"
|
||||
)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
117
tests/single_file/test_model_vae_single_file.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_hf_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class AutoencoderKLSingleFileTests(unittest.TestCase):
|
||||
model_class = AutoencoderKL
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
|
||||
)
|
||||
repo_id = "stabilityai/sd-vae-ft-mse"
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return image
|
||||
|
||||
def test_single_file_inference_same_as_pretrained(self):
|
||||
model_1 = self.model_class.from_pretrained(self.repo_id).to(torch_device)
|
||||
model_2 = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id).to(torch_device)
|
||||
|
||||
image = self.get_sd_image(33)
|
||||
|
||||
generator = torch.Generator(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
sample_1 = model_1(image, generator=generator.manual_seed(0)).sample
|
||||
sample_2 = model_2(image, generator=generator.manual_seed(0)).sample
|
||||
|
||||
assert sample_1.shape == sample_2.shape
|
||||
|
||||
output_slice_1 = sample_1.flatten().float().cpu()
|
||||
output_slice_2 = sample_2.flatten().float().cpu()
|
||||
|
||||
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between pretrained loading and single file loading"
|
||||
|
||||
def test_single_file_arguments(self):
|
||||
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
|
||||
|
||||
assert model_default.config.scaling_factor == 0.18215
|
||||
assert model_default.config.sample_size == 256
|
||||
assert model_default.dtype == torch.float32
|
||||
|
||||
scaling_factor = 2.0
|
||||
sample_size = 512
|
||||
torch_dtype = torch.float16
|
||||
|
||||
model = self.model_class.from_single_file(
|
||||
self.ckpt_path,
|
||||
config=self.repo_id,
|
||||
sample_size=sample_size,
|
||||
scaling_factor=scaling_factor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
assert model.config.scaling_factor == scaling_factor
|
||||
assert model.config.sample_size == sample_size
|
||||
assert model.dtype == torch_dtype
|
||||
@@ -0,0 +1,182 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .single_file_testing_utils import (
|
||||
SDSingleFileTesterMixin,
|
||||
download_diffusers_config,
|
||||
download_original_config,
|
||||
download_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
|
||||
original_config = (
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_img2img/sketch-mountains-input.png"
|
||||
)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
prompt = "bird"
|
||||
|
||||
inputs = {
|
||||
"prompt": prompt,
|
||||
"image": init_image,
|
||||
"control_image": control_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"strength": 0.75,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_single_file_format_inference_is_same_as_pretrained(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
pipe_sf = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
pipe_sf.unet.set_default_attn_processor()
|
||||
pipe_sf.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
output = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
output_sf = pipe_sf(**inputs).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_single_file_components(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id, variant="fp16", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path,
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_local_files_only(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_original_config(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, controlnet=controlnet, safety_checker=None, original_config=self.original_config
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_original_config_local_files_only(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_original_config = download_original_config(self.original_config, tmpdir)
|
||||
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
local_ckpt_path,
|
||||
original_config=local_original_config,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
local_files_only=True,
|
||||
)
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_diffusers_config(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, controlnet=controlnet, safety_checker=None, config=self.repo_id
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_diffusers_config_local_files_only(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
|
||||
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
local_ckpt_path,
|
||||
config=local_diffusers_config,
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
local_files_only=True,
|
||||
)
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
@@ -0,0 +1,183 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .single_file_testing_utils import (
|
||||
SDSingleFileTesterMixin,
|
||||
download_diffusers_config,
|
||||
download_original_config,
|
||||
download_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
|
||||
pipeline_class = StableDiffusionControlNetInpaintPipeline
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
|
||||
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
|
||||
repo_id = "runwayml/stable-diffusion-inpainting"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self):
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
).resize((512, 512))
|
||||
|
||||
inputs = {
|
||||
"prompt": "bird",
|
||||
"image": image,
|
||||
"control_image": control_image,
|
||||
"mask_image": mask_image,
|
||||
"generator": torch.Generator(device="cpu").manual_seed(0),
|
||||
"num_inference_steps": 3,
|
||||
"output_type": "np",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_single_file_format_inference_is_same_as_pretrained(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet, safety_checker=None)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
pipe_sf = self.pipeline_class.from_single_file(self.ckpt_path, controlnet=controlnet, safety_checker=None)
|
||||
pipe_sf.unet.set_default_attn_processor()
|
||||
pipe_sf.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
output = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_inputs()
|
||||
output_sf = pipe_sf(**inputs).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_single_file_components(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id, variant="fp16", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path,
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_local_files_only(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None, controlnet=controlnet)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_original_config(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, controlnet=controlnet, original_config=self.original_config
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_original_config_local_files_only(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_original_config = download_original_config(self.original_config, tmpdir)
|
||||
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
local_ckpt_path,
|
||||
original_config=local_original_config,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
local_files_only=True,
|
||||
)
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_diffusers_config(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path,
|
||||
controlnet=controlnet,
|
||||
config=self.repo_id,
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_diffusers_config_local_files_only(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
|
||||
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
local_ckpt_path,
|
||||
config=local_diffusers_config,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
local_files_only=True,
|
||||
)
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
@@ -0,0 +1,171 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .single_file_testing_utils import (
|
||||
SDSingleFileTesterMixin,
|
||||
download_diffusers_config,
|
||||
download_original_config,
|
||||
download_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
|
||||
original_config = (
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self):
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
inputs = {
|
||||
"prompt": "bird",
|
||||
"image": control_image,
|
||||
"generator": torch.Generator(device="cpu").manual_seed(0),
|
||||
"num_inference_steps": 3,
|
||||
"output_type": "np",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_single_file_format_inference_is_same_as_pretrained(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
pipe_sf = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
pipe_sf.unet.set_default_attn_processor()
|
||||
pipe_sf.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
output = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_inputs()
|
||||
output_sf = pipe_sf(**inputs).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_single_file_components(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id, variant="fp16", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path,
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_local_files_only(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
local_ckpt_path, controlnet=controlnet, local_files_only=True
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_original_config(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, controlnet=controlnet, original_config=self.original_config
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_original_config_local_files_only(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_original_config = download_original_config(self.original_config, tmpdir)
|
||||
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
local_ckpt_path, original_config=local_original_config, controlnet=controlnet, local_files_only=True
|
||||
)
|
||||
pipe_single_file.scheduler = pipe.scheduler
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_diffusers_config(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, controlnet=controlnet, safety_checker=None, config=self.repo_id
|
||||
)
|
||||
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_components_with_diffusers_config_local_files_only(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_filename = self.ckpt_path.split("/")[-1]
|
||||
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
|
||||
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
|
||||
|
||||
pipe_single_file = self.pipeline_class.from_single_file(
|
||||
local_ckpt_path,
|
||||
config=local_diffusers_config,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
local_files_only=True,
|
||||
)
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
@@ -0,0 +1,99 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .single_file_testing_utils import SDSingleFileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
|
||||
pipeline_class = StableDiffusionImg2ImgPipeline
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
|
||||
original_config = (
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_img2img/sketch-mountains-input.png"
|
||||
)
|
||||
inputs = {
|
||||
"prompt": "a fantasy landscape, concept art, high resolution",
|
||||
"image": init_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"strength": 0.75,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_single_file_format_inference_is_same_as_pretrained(self):
|
||||
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
|
||||
pipeline_class = StableDiffusionImg2ImgPipeline
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
|
||||
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
repo_id = "stabilityai/stable-diffusion-2-1"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_img2img/sketch-mountains-input.png"
|
||||
)
|
||||
inputs = {
|
||||
"prompt": "a fantasy landscape, concept art, high resolution",
|
||||
"image": init_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"strength": 0.75,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_single_file_format_inference_is_same_as_pretrained(self):
|
||||
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
|
||||
114
tests/single_file/test_stable_diffusion_inpaint_single_file.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
StableDiffusionInpaintPipeline,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .single_file_testing_utils import SDSingleFileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
|
||||
pipeline_class = StableDiffusionInpaintPipeline
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
|
||||
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
|
||||
repo_id = "runwayml/stable-diffusion-inpainting"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
)
|
||||
inputs = {
|
||||
"prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
|
||||
"image": init_image,
|
||||
"mask_image": mask_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_single_file_format_inference_is_same_as_pretrained(self):
|
||||
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
|
||||
|
||||
def test_single_file_loading_4_channel_unet(self):
|
||||
# Test loading single file inpaint with a 4 channel UNet
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
|
||||
pipe = self.pipeline_class.from_single_file(ckpt_path)
|
||||
|
||||
assert pipe.unet.config.in_channels == 4
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
|
||||
pipeline_class = StableDiffusionInpaintPipeline
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/512-inpainting-ema.safetensors"
|
||||
)
|
||||
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inpainting-inference.yaml"
|
||||
repo_id = "stabilityai/stable-diffusion-2-inpainting"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
)
|
||||
inputs = {
|
||||
"prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
|
||||
"image": init_image,
|
||||
"mask_image": mask_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_single_file_format_inference_is_same_as_pretrained(self):
|
||||
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
|
||||
114
tests/single_file/test_stable_diffusion_single_file.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import gc
import tempfile
import unittest

import torch

from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    require_torch_gpu,
    slow,
)

from .single_file_testing_utils import (
    SDSingleFileTesterMixin,
    download_original_config,
    download_single_file_checkpoint,
)


enable_full_determinism()


@slow
@require_torch_gpu
class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
    pipeline_class = StableDiffusionPipeline
    ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
    original_config = (
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
    )
    repo_id = "runwayml/stable-diffusion-v1-5"

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        inputs = {
            "prompt": "a fantasy landscape, concept art, high resolution",
            "generator": generator,
            "num_inference_steps": 2,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    def test_single_file_format_inference_is_same_as_pretrained(self):
        super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)

    def test_single_file_legacy_scheduler_loading(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_filename = self.ckpt_path.split("/")[-1]
            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
            local_original_config = download_original_config(self.original_config, tmpdir)

            pipe = self.pipeline_class.from_single_file(
                local_ckpt_path,
                original_config=local_original_config,
                cache_dir=tmpdir,
                local_files_only=True,
                scheduler_type="euler",
            )

            # Default is PNDM for this checkpoint
            assert isinstance(pipe.scheduler, EulerDiscreteScheduler)

    def test_single_file_legacy_scaling_factor(self):
        new_scaling_factor = 10.0
        init_pipe = self.pipeline_class.from_single_file(self.ckpt_path)
        pipe = self.pipeline_class.from_single_file(self.ckpt_path, scaling_factor=new_scaling_factor)

        assert init_pipe.vae.config.scaling_factor != new_scaling_factor
        assert pipe.vae.config.scaling_factor == new_scaling_factor


class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
    pipeline_class = StableDiffusionPipeline
    ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
    original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
    repo_id = "stabilityai/stable-diffusion-2-1"

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        inputs = {
            "prompt": "a fantasy landscape, concept art, high resolution",
            "generator": generator,
            "num_inference_steps": 2,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    def test_single_file_format_inference_is_same_as_pretrained(self):
        super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
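# The local_files_only tests above rely on small download helpers from .single_file_testing_utils.
# Their bodies are not part of this hunk; the sketch below only illustrates what they are assumed
# to do (names and signatures are taken from the call sites above, the implementations are an
# assumption, not the code shipped in this PR):

import os

import requests
from huggingface_hub import hf_hub_download


def download_single_file_checkpoint(repo_id, filename, tmpdir):
    # Fetch the single checkpoint file from the Hub into a local directory so the
    # from_single_file(..., local_files_only=True) code path can be exercised offline.
    return hf_hub_download(repo_id, filename=filename, local_dir=tmpdir)


def download_original_config(config_url, tmpdir):
    # Save the original (LDM-style) YAML config next to the checkpoint.
    path = os.path.join(tmpdir, config_url.split("/")[-1])
    with open(path, "wb") as f:
        f.write(requests.get(config_url).content)
    return path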
@@ -0,0 +1,68 @@
import gc
import unittest

import torch

from diffusers import (
    StableDiffusionUpscalePipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
)

from .single_file_testing_utils import SDSingleFileTesterMixin


enable_full_determinism()


@slow
@require_torch_gpu
class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
    pipeline_class = StableDiffusionUpscalePipeline
    ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
    original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
    repo_id = "stabilityai/stable-diffusion-x4-upscaler"

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_single_file_format_inference_is_same_as_pretrained(self):
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )

        prompt = "a cat sitting on a park bench"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(self.repo_id)
        pipe.enable_model_cpu_offload()

        generator = torch.Generator("cpu").manual_seed(0)
        output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
        image_from_pretrained = output.images[0]

        pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(self.ckpt_path)
        pipe_from_single_file.enable_model_cpu_offload()

        generator = torch.Generator("cpu").manual_seed(0)
        output_from_single_file = pipe_from_single_file(
            prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
        )
        image_from_single_file = output_from_single_file.images[0]

        assert image_from_pretrained.shape == (512, 512, 3)
        assert image_from_single_file.shape == (512, 512, 3)
        assert (
            numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
        )
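# numpy_cosine_similarity_distance(a, b) treats the two flattened images as vectors and returns a
# distance close to 0 when the outputs are nearly identical. A minimal sketch of the idea (the real
# helper lives in diffusers.utils.testing_utils; this reimplementation is only an illustration):

import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of the flattened arrays; the assertion above bounds how far the
    # single-file pipeline's output may drift from the from_pretrained output.
    a, b = a.flatten(), b.flatten()
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))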
@@ -0,0 +1,202 @@
import gc
import tempfile
import unittest

import torch

from diffusers import (
    StableDiffusionXLAdapterPipeline,
    T2IAdapter,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
)

from .single_file_testing_utils import (
    SDXLSingleFileTesterMixin,
    download_diffusers_config,
    download_original_config,
    download_single_file_checkpoint,
)


enable_full_determinism()


@slow
@require_torch_gpu
class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
    pipeline_class = StableDiffusionXLAdapterPipeline
    ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
    repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
    original_config = (
        "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
    )

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self):
        prompt = "toy"
        generator = torch.Generator(device="cpu").manual_seed(0)
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
        )

        inputs = {
            "prompt": prompt,
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 7.5,
            "output_type": "np",
        }

        return inputs

    def test_single_file_format_inference_is_same_as_pretrained(self):
        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
        pipe_single_file = StableDiffusionXLAdapterPipeline.from_single_file(
            self.ckpt_path,
            adapter=adapter,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        pipe_single_file.enable_model_cpu_offload()
        pipe_single_file.set_progress_bar_config(disable=None)

        inputs = self.get_inputs()
        images_single_file = pipe_single_file(**inputs).images[0]

        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
            self.repo_id,
            adapter=adapter,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        pipe.enable_model_cpu_offload()

        inputs = self.get_inputs()
        images = pipe(**inputs).images[0]

        assert images_single_file.shape == (768, 512, 3)
        assert images.shape == (768, 512, 3)

        max_diff = numpy_cosine_similarity_distance(images.flatten(), images_single_file.flatten())
        assert max_diff < 5e-3

    def test_single_file_components(self):
        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            adapter=adapter,
            torch_dtype=torch.float16,
        )

        pipe_single_file = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None, adapter=adapter)
        super().test_single_file_components(pipe, pipe_single_file)

    def test_single_file_components_local_files_only(self):
        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            adapter=adapter,
            torch_dtype=torch.float16,
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_filename = self.ckpt_path.split("/")[-1]
            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)

            single_file_pipe = self.pipeline_class.from_single_file(
                local_ckpt_path, adapter=adapter, safety_checker=None, local_files_only=True
            )

            self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config(self):
        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            adapter=adapter,
            torch_dtype=torch.float16,
            safety_checker=None,
        )

        pipe_single_file = self.pipeline_class.from_single_file(self.ckpt_path, config=self.repo_id, adapter=adapter)
        self._compare_component_configs(pipe, pipe_single_file)

    def test_single_file_components_with_diffusers_config_local_files_only(self):
        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            adapter=adapter,
            torch_dtype=torch.float16,
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_filename = self.ckpt_path.split("/")[-1]
            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
            local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

            pipe_single_file = self.pipeline_class.from_single_file(
                local_ckpt_path,
                config=local_diffusers_config,
                adapter=adapter,
                safety_checker=None,
                local_files_only=True,
            )
            self._compare_component_configs(pipe, pipe_single_file)

    def test_single_file_components_with_original_config(self):
        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            adapter=adapter,
            torch_dtype=torch.float16,
            safety_checker=None,
        )

        pipe_single_file = self.pipeline_class.from_single_file(
            self.ckpt_path, original_config=self.original_config, adapter=adapter
        )
        self._compare_component_configs(pipe, pipe_single_file)

    def test_single_file_components_with_original_config_local_files_only(self):
        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            adapter=adapter,
            torch_dtype=torch.float16,
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_filename = self.ckpt_path.split("/")[-1]
            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
            local_original_config = download_original_config(self.original_config, tmpdir)

            pipe_single_file = self.pipeline_class.from_single_file(
                local_ckpt_path,
                original_config=local_original_config,
                adapter=adapter,
                safety_checker=None,
                local_files_only=True,
            )
            self._compare_component_configs(pipe, pipe_single_file)
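# Note on the tests above: the T2I-Adapter weights are not stored inside the single
# sd_xl_base_1.0.safetensors checkpoint, so every from_single_file call receives the adapter
# instance explicitly, mirroring how from_pretrained receives it. config=self.repo_id points
# from_single_file at the diffusers-format configs on the Hub, while original_config points it at
# the original Stability-AI YAML; both resolution paths are exercised, remotely and with
# local_files_only=True.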
@@ -0,0 +1,197 @@
import gc
import tempfile
import unittest

import torch

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)

from .single_file_testing_utils import (
    SDXLSingleFileTesterMixin,
    download_diffusers_config,
    download_single_file_checkpoint,
)


enable_full_determinism()


@slow
@require_torch_gpu
class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
    pipeline_class = StableDiffusionXLControlNetPipeline
    ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
    repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
    original_config = (
        "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
    )

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
        )
        inputs = {
            "prompt": "Stormtrooper's lecture",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "np",
        }

        return inputs

    def test_single_file_format_inference_is_same_as_pretrained(self):
        controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16)
        pipe_single_file = self.pipeline_class.from_single_file(
            self.ckpt_path, controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe_single_file.unet.set_default_attn_processor()
        pipe_single_file.enable_model_cpu_offload()
        pipe_single_file.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        single_file_images = pipe_single_file(**inputs).images[0]

        pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet, torch_dtype=torch.float16)
        pipe.unet.set_default_attn_processor()
        pipe.enable_model_cpu_offload()

        inputs = self.get_inputs(torch_device)
        images = pipe(**inputs).images[0]

        assert images.shape == (512, 512, 3)
        assert single_file_images.shape == (512, 512, 3)

        max_diff = numpy_cosine_similarity_distance(images[0].flatten(), single_file_images[0].flatten())
        assert max_diff < 5e-2

    def test_single_file_components(self):
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            controlnet=controlnet,
            torch_dtype=torch.float16,
        )

        pipe_single_file = self.pipeline_class.from_single_file(self.ckpt_path, controlnet=controlnet)
        super().test_single_file_components(pipe, pipe_single_file)

    def test_single_file_components_local_files_only(self):
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            controlnet=controlnet,
            torch_dtype=torch.float16,
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_filename = self.ckpt_path.split("/")[-1]
            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)

            single_file_pipe = self.pipeline_class.from_single_file(
                local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
            )

            self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config(self):
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            controlnet=controlnet,
            torch_dtype=torch.float16,
        )

        pipe_single_file = self.pipeline_class.from_single_file(
            self.ckpt_path,
            original_config=self.original_config,
            controlnet=controlnet,
        )
        self._compare_component_configs(pipe, pipe_single_file)

    def test_single_file_components_with_original_config_local_files_only(self):
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            variant="fp16",
            controlnet=controlnet,
            torch_dtype=torch.float16,
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_filename = self.ckpt_path.split("/")[-1]
            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)

            pipe_single_file = self.pipeline_class.from_single_file(
                local_ckpt_path,
                safety_checker=None,
                controlnet=controlnet,
                local_files_only=True,
            )
            self._compare_component_configs(pipe, pipe_single_file)

    def test_single_file_components_with_diffusers_config(self):
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
        pipe_single_file = self.pipeline_class.from_single_file(
            self.ckpt_path, controlnet=controlnet, config=self.repo_id
        )

        super()._compare_component_configs(pipe, pipe_single_file)

    def test_single_file_components_with_diffusers_config_local_files_only(self):
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id,
            controlnet=controlnet,
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_filename = self.ckpt_path.split("/")[-1]
            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
            local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

            pipe_single_file = self.pipeline_class.from_single_file(
                local_ckpt_path,
                config=local_diffusers_config,
                safety_checker=None,
                controlnet=controlnet,
                local_files_only=True,
            )
            super()._compare_component_configs(pipe, pipe_single_file)
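# The ControlNet tests follow the same pattern as the adapter tests: the depth ControlNet is loaded
# separately and passed into from_single_file, and the component configs produced from the single
# checkpoint are compared against from_pretrained under each config-resolution path (default,
# original_config YAML, and a diffusers-format config via config=..., both remote and with
# local_files_only=True).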
@@ -0,0 +1,105 @@
import gc
import unittest

import torch

from diffusers import (
    DDIMScheduler,
    StableDiffusionXLImg2ImgPipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
)

from .single_file_testing_utils import SDXLSingleFileTesterMixin


enable_full_determinism()


@slow
@require_torch_gpu
class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
    pipeline_class = StableDiffusionXLImg2ImgPipeline
    ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
    repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
    original_config = (
        "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
    )

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_img2img/sketch-mountains-input.png"
        )
        inputs = {
            "prompt": "a fantasy landscape, concept art, high resolution",
            "image": init_image,
            "generator": generator,
            "num_inference_steps": 3,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    def test_single_file_format_inference_is_same_as_pretrained(self):
        super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)


@slow
@require_torch_gpu
class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests(unittest.TestCase):
    pipeline_class = StableDiffusionXLImg2ImgPipeline
    ckpt_path = (
        "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
    )
    repo_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
    original_config = (
        "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
    )

    def test_single_file_format_inference_is_same_as_pretrained(self):
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_img2img/sketch-mountains-input.png"
        )

        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.unet.set_default_attn_processor()
        pipe.enable_model_cpu_offload()

        generator = torch.Generator(device="cpu").manual_seed(0)
        image = pipe(
            prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
        ).images[0]

        pipe_single_file = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16)
        pipe_single_file.scheduler = DDIMScheduler.from_config(pipe_single_file.scheduler.config)
        pipe_single_file.unet.set_default_attn_processor()
        pipe_single_file.enable_model_cpu_offload()

        generator = torch.Generator(device="cpu").manual_seed(0)
        image_single_file = pipe_single_file(
            prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
        ).images[0]

        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < 5e-4
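# To keep the refiner comparison deterministic, both pipelines swap in a DDIMScheduler built from
# the same scheduler config, reset the default attention processor, and reuse a fixed CPU-seeded
# generator before their outputs are compared with the cosine-distance helper.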
@@ -0,0 +1,50 @@
import gc
import unittest

import torch

from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    require_torch_gpu,
    slow,
)


enable_full_determinism()


@slow
@require_torch_gpu
class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase):
    pipeline_class = StableDiffusionXLInstructPix2PixPipeline
    ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
    original_config = None
    repo_id = "diffusers/sdxl-instructpix2pix-768"

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        inputs = {
            "prompt": "a fantasy landscape, concept art, high resolution",
            "generator": generator,
            "num_inference_steps": 2,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    def test_single_file_setting_cosxl_edit(self):
        # The CosXL Edit checkpoint requires is_cosxl_edit=True; check that the flag is forwarded to the pipeline
        pipe = self.pipeline_class.from_single_file(self.ckpt_path, config=self.repo_id, is_cosxl_edit=True)
        assert pipe.is_cosxl_edit is True
54
tests/single_file/test_stable_diffusion_xl_single_file.py
Normal file
@@ -0,0 +1,54 @@
import gc
import unittest

import torch

from diffusers import (
    StableDiffusionXLPipeline,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    require_torch_gpu,
    slow,
)

from .single_file_testing_utils import SDXLSingleFileTesterMixin


enable_full_determinism()


@slow
@require_torch_gpu
class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
    pipeline_class = StableDiffusionXLPipeline
    ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
    repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
    original_config = (
        "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
    )

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        inputs = {
            "prompt": "a fantasy landscape, concept art, high resolution",
            "generator": generator,
            "num_inference_steps": 2,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    def test_single_file_format_inference_is_same_as_pretrained(self):
        super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)