[Community] Implementation of the IADB community pipeline (#3996)
* community pipeline: implementation of iadb
* iadb.py: reformat using black
* iadb.py: linting update
@@ -38,6 +38,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) |
| CLIP Guided Images Mixing Stable Diffusion Pipeline | Combine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) |
| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, set to one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
```py
@@ -1707,3 +1708,62 @@ output = pipeline(
```

### IADB pipeline
This pipeline is an implementation of the [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) paper.
It is a simple and minimalist diffusion model.
The following code shows how to use the IADB pipeline to generate images using a pretrained celebahq-256 model.
```python
import matplotlib.pyplot as plt

from diffusers import DiffusionPipeline

pipeline_iadb = DiffusionPipeline.from_pretrained("thomasc4/iadb-celebahq-256", custom_pipeline='iadb')
pipeline_iadb = pipeline_iadb.to('cuda')

output = pipeline_iadb(batch_size=4, num_inference_steps=128)

for i in range(len(output[0])):
    plt.imshow(output[0][i])
    plt.show()
```
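
With the default `output_type="pil"`, `output[0]` is a list of `PIL.Image.Image` objects, so the generated images can also be written to disk instead of being displayed. This is a minimal sketch, not part of the original example; the file names are arbitrary:

```python
# Minimal sketch: save the generated images (assumes `output` from the example above)
for i, image in enumerate(output[0]):
    image.save(f"iadb_sample_{i}.png")
```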
Sampling with the IADB formulation is easy, and can be done in a few lines (the pipeline already implements it):
```python
import torch


def sample_iadb(model, x0, nb_step):
    # x0 is pure Gaussian noise; the model predicts the direction (x1 - x0)
    x_alpha = x0
    for t in range(nb_step):
        alpha = t / nb_step
        alpha_next = (t + 1) / nb_step

        d = model(x_alpha, torch.tensor(alpha, device=x_alpha.device))['sample']
        x_alpha = x_alpha + (alpha_next - alpha) * d

    return x_alpha
```
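
The standalone function above can be driven directly with the UNet from the pretrained pipeline. The snippet below is a minimal sketch, not part of the original example: it assumes `pipeline_iadb` from the first example is already loaded, and that the celebahq-256 model works on 3-channel 256x256 samples.

```python
# Minimal sketch (assumptions: pipeline_iadb is loaded on CUDA, samples are 3x256x256)
x0 = torch.randn(4, 3, 256, 256, device='cuda')             # start from pure Gaussian noise
samples = sample_iadb(pipeline_iadb.unet, x0, nb_step=128)  # deterministic blending from noise to images
images = (samples * 0.5 + 0.5).clamp(0, 1)                  # map from [-1, 1] back to [0, 1], as the pipeline does
```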
The training loop is also straightforward:
```python
# Training loop
# (sample_noise, sample_dataset, the network D, optimizer and batch_size are assumed to be defined)
while True:
    x0 = sample_noise()
    x1 = sample_dataset()

    alpha = torch.rand(batch_size)

    # Blend
    x_alpha = (1 - alpha) * x0 + alpha * x1

    # Loss
    loss = torch.sum((D(x_alpha, alpha) - (x1 - x0)) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
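
For a slightly more concrete version of this loop using the building blocks from this pipeline, the blend can be computed with `IADBScheduler.add_noise` and the network can be a `UNet2DModel`. The sketch below is illustrative only: the model configuration, the `dataloader` (assumed to yield image batches scaled to [-1, 1]) and the optimizer settings are assumptions, not part of the original example.

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel(sample_size=64, in_channels=3, out_channels=3).to("cuda")  # hypothetical config
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = IADBScheduler()  # from examples/community/iadb.py

for x1 in dataloader:  # x1: clean images in [-1, 1] (assumed)
    x1 = x1.to("cuda")
    x0 = torch.randn_like(x1)                          # Gaussian noise
    alpha = torch.rand(x1.shape[0], device=x1.device)  # one blending level per sample

    # Blend: x_alpha = alpha * x1 + (1 - alpha) * x0
    x_alpha = scheduler.add_noise(x1, x0, alpha.view(-1, 1, 1, 1))

    # The network learns the direction (x1 - x0)
    d = model(x_alpha, alpha).sample
    loss = torch.mean((d - (x1 - x0)) ** 2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```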
examples/community/iadb.py (new file, 149 lines)
@@ -0,0 +1,149 @@
```python
from typing import List, Optional, Tuple, Union

import torch

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import ImagePipelineOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin


class IADBScheduler(SchedulerMixin, ConfigMixin):
    """
    IADBScheduler is a scheduler for the Iterative α-(de)Blending denoising method. It is simple and minimalist.

    For more details, see the original paper: https://arxiv.org/abs/2305.03486 and the blog post: https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html
    """

    # set by `set_timesteps`; `step` raises a ValueError while it is still None
    num_inference_steps = None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x_alpha: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Predict the sample at the next alpha value (one deterministic blending step along the ODE). Core function to
        propagate the diffusion process from the learned model output (the direction from x0 to x1).

        Args:
            model_output (`torch.FloatTensor`): direct output from the learned diffusion model. It is the direction from x0 to x1.
            timestep (`int`): current timestep in the diffusion chain.
            x_alpha (`torch.FloatTensor`): x_alpha sample for the current timestep.

        Returns:
            `torch.FloatTensor`: the sample at the next timestep

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        alpha = timestep / self.num_inference_steps
        alpha_next = (timestep + 1) / self.num_inference_steps

        d = model_output

        x_alpha = x_alpha + (alpha_next - alpha) * d

        return x_alpha

    def set_timesteps(self, num_inference_steps: int):
        self.num_inference_steps = num_inference_steps

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        alpha: torch.FloatTensor,
    ) -> torch.FloatTensor:
        return original_samples * alpha + noise * (1 - alpha)

    def __len__(self):
        return self.config.num_train_timesteps


class IADBPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. For this pipeline it
            should be an [`IADBScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        # Sample Gaussian noise to begin the loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)
        x_alpha = image.clone()
        for t in self.progress_bar(range(num_inference_steps)):
            alpha = t / num_inference_steps

            # 1. predict the blending direction d = (x1 - x0)
            model_output = self.unet(x_alpha, torch.tensor(alpha, device=x_alpha.device)).sample

            # 2. step: move x_alpha one step towards the data distribution
            x_alpha = self.scheduler.step(model_output, t, x_alpha)

        image = (x_alpha * 0.5 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
```