mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
committed by
GitHub
parent
1d7adf1329
commit
e7457b377d
@@ -10,19 +10,24 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Models
|
||||
# Diffusion Pipeline
|
||||
|
||||
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
|
||||
The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
|
||||
The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
|
||||
The [`DiffusionPipeline`] is the easiest way to load any pretrained diffusion pipeline from the [Hub](https://huggingface.co/models?library=diffusers) and to use it in inference.
|
||||
|
||||
## API
|
||||
<Tip>
|
||||
|
||||
One should not use the Diffusion Pipeline class for training or fine-tuning a diffusion model. Individual
|
||||
components of diffusion pipelines are usually trained individually, so we suggest to directly work
|
||||
with [`UNetModel`] and [`UNetConditionModel`].
|
||||
|
||||
Models should provide the `def forward` function and initialization of the model.
|
||||
All saving, loading, and utilities should be in the base ['ModelMixin'] class.
|
||||
</Tip>
|
||||
|
||||
## Examples
|
||||
Any diffusion pipeline that is loaded with [`~DiffusionPipeline.from_pretrained`] will automatically
|
||||
detect the pipeline type, *e.g.* [`StableDiffusionPipeline`] and consequently load each component of the
|
||||
pipeline and pass them into the `__init__` function of the pipeline, *e.g.* [`~StableDiffusionPipeline.__init__`].
|
||||
|
||||
- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
|
||||
- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
|
||||
- TODO: mention VAE / SDE score estimation
|
||||
Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrained`].
|
||||
|
||||
[[autodoc]] DiffusionPipeline
|
||||
- from_pretrained
|
||||
- save_pretrained
|
||||
|
||||
@@ -1,39 +1,35 @@
|
||||
# Stable diffusion pipelines
|
||||
|
||||
## Overview
|
||||
|
||||
Stable Diffusion is a text-to-image _latent diffusion_ model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs.
|
||||
|
||||
Latent diffusion is the research on top of which Stable Diffusion was built. It was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer. You can learn more details about it in the [specific pipeline for latent diffusion](pipelines/latent_diffusion) that is part of 🤗 Diffusers.
|
||||
|
||||
For more details about how Stable Diffusion works and how it differs from the base latent diffusion model, please refer to the official [launch announcement post](https://stability.ai/blog/stable-diffusion-announcement) and [this section of our own blog post](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work).
|
||||
|
||||
## Tips
|
||||
|
||||
*Tips*:
|
||||
- To tweak your prompts on a specific result you liked, you can generate your own latents, as demonstrated in the following notebook: [](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb)
|
||||
- TODO: some interesting Tips
|
||||
|
||||
## Available pipelines
|
||||
|
||||
| Pipeline | Tasks | Colab | Demo
|
||||
|---|---|:---:|:---:|
|
||||
| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) | [🤗 Stable Diffusion](https://huggingface.co/spaces/stabilityai/stable-diffusion)
|
||||
| [pipeline_stable_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) | [🤗 Diffuse the Rest](https://huggingface.co/spaces/huggingface/diffuse-the-rest)
|
||||
| [pipeline_stable_diffusion_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | **Experimental** – *Text-Guided Image Inpainting* | [](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) | Coming soon
|
||||
|
||||
## API
|
||||
|
||||
[[autodoc]] StableDiffusionPipeline
|
||||
- __init__
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
[[autodoc]] StableDiffusionImg2ImgPipeline
|
||||
- __init__
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
[[autodoc]] StableDiffusionInpaintPipeline
|
||||
- __init__
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Install Diffusers for with PyTorch. Support for other libraries will come in the future
|
||||
|
||||
🤗 Diffusers is tested on Python 3.6+, and PyTorch 1.4.0+.
|
||||
🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+.
|
||||
|
||||
## Install with pip
|
||||
|
||||
|
||||
@@ -72,6 +72,20 @@ class ImagePipelineOutput(BaseOutput):
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
||||
[`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
|
||||
and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
|
||||
|
||||
- move all PyTorch modules to the device of your choice
|
||||
- enabling/disabling the progress bar for the denoising iteration
|
||||
|
||||
Class attributes:
|
||||
|
||||
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
|
||||
compenents of the diffusion pipeline.
|
||||
"""
|
||||
config_name = "model_index.json"
|
||||
|
||||
def register_modules(self, **kwargs):
|
||||
@@ -105,6 +119,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
setattr(self, name, module)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
|
||||
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
|
||||
method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
"""
|
||||
self.save_config(save_directory)
|
||||
|
||||
model_index_dict = dict(self.config)
|
||||
@@ -145,6 +168,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
r"""
|
||||
Returns:
|
||||
`torch.device`: The torch device on which the pipeline is located.
|
||||
"""
|
||||
module_names, _ = self.extract_init_dict(dict(self.config))
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
@@ -155,7 +182,94 @@ class DiffusionPipeline(ConfigMixin):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
Add docstrings
|
||||
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
|
||||
|
||||
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
|
||||
|
||||
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
||||
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
||||
task.
|
||||
|
||||
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
||||
weights are discarded.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
|
||||
https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
|
||||
`CompVis/ldm-text2im-large-256`.
|
||||
- A path to a *directory* containing pipeline weights saved using
|
||||
[`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information. specify the folder name here.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||
speficic pipeline class. The overritten components are then directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.*
|
||||
`"CompVis/stable-diffusion-v1-4"`
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
|
||||
this method in a firewalled environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
|
||||
>>> # Download pipeline from huggingface.co and cache.
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
|
||||
>>> # Download pipeline that requires an authorization token
|
||||
>>> # For more information on access tokens, please refer to this section
|
||||
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
|
||||
|
||||
>>> # Download pipeline, but overwrite scheduler
|
||||
>>> from diffusers import LMSDiscreteScheduler
|
||||
|
||||
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained(
|
||||
... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
|
||||
... )
|
||||
```
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
|
||||
Reference in New Issue
Block a user