1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
Patrick von Platen
2022-06-15 12:35:51 +02:00

View File

@@ -58,12 +58,14 @@ git clone https://github.com/huggingface/diffusers.git
cd diffusers && pip install -e .
```
### 1. `diffusers` as a central modular diffusion and sampler library
### 1. `diffusers` as a toolbox for schedulers and models.
`diffusers` is more modularized than `transformers`. The idea is that researchers and engineers can use only parts of the library easily for the own use cases.
It could become a central place for all kinds of models, schedulers, training utils and processors that one can mix and match for one's own use case.
Both models and schedulers should be load- and saveable from the Hub.
For more examples see [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) and [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)
#### **Example for [DDPM](https://arxiv.org/abs/2006.11239):**
```python
@@ -171,25 +173,35 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
### 2. `diffusers` as a collection of most important Diffusion systems (GLIDE, Dalle, ...)
`models` directory in repository hosts the complete code necessary for running a diffusion system as well as to train it. A `DiffusionPipeline` class allows to easily run the diffusion model in inference:
### 2. `diffusers` as a collection of popula Diffusion systems (GLIDE, Dalle, ...)
#### **Example image generation with DDPM**
For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
#### **Example image generation with PNDM**
```python
from diffusers import DiffusionPipeline
from diffusers import PNDM, UNetModel, PNDMScheduler
import PIL.Image
import numpy as np
import torch
model_id = "fusing/ddim-celeba-hq"
model = UNetModel.from_pretrained(model_id)
scheduler = PNDMScheduler()
# load model and scheduler
ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom")
ddpm = PNDM(unet=model, noise_scheduler=scheduler)
# run pipeline in inference (sample random noise and denoise)
image = ddpm()
with torch.no_grad():
image = ddpm()
# process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = (image_processed + 1.0) / 2
image_processed = torch.clamp(image_processed, 0.0, 1.0)
image_processed = image_processed * 255
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
@@ -255,61 +267,3 @@ from scipy.io.wavfile import write as wavwrite
sampling_rate = 22050
wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
```
## Library structure:
```
β”œβ”€β”€ LICENSE
β”œβ”€β”€ Makefile
β”œβ”€β”€ README.md
β”œβ”€β”€ pyproject.toml
β”œβ”€β”€ setup.cfg
β”œβ”€β”€ setup.py
β”œβ”€β”€ src
β”‚ β”œβ”€β”€ diffusers
β”‚ β”œβ”€β”€ __init__.py
β”‚ β”œβ”€β”€ configuration_utils.py
β”‚ β”œβ”€β”€ dependency_versions_check.py
β”‚ β”œβ”€β”€ dependency_versions_table.py
β”‚ β”œβ”€β”€ dynamic_modules_utils.py
β”‚ β”œβ”€β”€ modeling_utils.py
β”‚ β”œβ”€β”€ models
β”‚ β”‚ β”œβ”€β”€ __init__.py
β”‚ β”‚ β”œβ”€β”€ unet.py
β”‚ β”‚ β”œβ”€β”€ unet_glide.py
β”‚ β”‚ └── unet_ldm.py
β”‚ β”œβ”€β”€ pipeline_utils.py
β”‚ β”œβ”€β”€ pipelines
β”‚ β”‚ β”œβ”€β”€ __init__.py
β”‚ β”‚ β”œβ”€β”€ configuration_ldmbert.py
β”‚ β”‚ β”œβ”€β”€ conversion_glide.py
β”‚ β”‚ β”œβ”€β”€ modeling_vae.py
β”‚ β”‚ β”œβ”€β”€ pipeline_bddm.py
β”‚ β”‚ β”œβ”€β”€ pipeline_ddim.py
β”‚ β”‚ β”œβ”€β”€ pipeline_ddpm.py
β”‚ β”‚ β”œβ”€β”€ pipeline_glide.py
β”‚ β”‚ └── pipeline_latent_diffusion.py
β”‚ β”œβ”€β”€ schedulers
β”‚ β”‚ β”œβ”€β”€ __init__.py
β”‚ β”‚ β”œβ”€β”€ classifier_free_guidance.py
β”‚ β”‚ β”œβ”€β”€ scheduling_ddim.py
β”‚ β”‚ β”œβ”€β”€ scheduling_ddpm.py
β”‚ β”‚ β”œβ”€β”€ scheduling_plms.py
β”‚ β”‚ └── scheduling_utils.py
β”‚ β”œβ”€β”€ testing_utils.py
β”‚ └── utils
β”‚ β”œβ”€β”€ __init__.py
β”‚ └── logging.py
β”œβ”€β”€ tests
β”‚ β”œβ”€β”€ __init__.py
β”‚ β”œβ”€β”€ test_modeling_utils.py
β”‚ └── test_scheduler.py
└── utils
β”œβ”€β”€ check_config_docstrings.py
β”œβ”€β”€ check_copies.py
β”œβ”€β”€ check_dummies.py
β”œβ”€β”€ check_inits.py
β”œβ”€β”€ check_repo.py
β”œβ”€β”€ check_table.py
└── check_tf_ops.py
```