mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Batched Generators] This PR adds generators that are useful to make batched generation fully reproducible (#1718)
* [Batched Generators] all batched generators * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * hey * up again * fix tests * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * correct tests Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
086c7f9ea8
commit
c53a850604
@@ -302,11 +302,8 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
|
||||
### Tweak prompts reusing seeds and latents
|
||||
|
||||
You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
|
||||
|
||||
|
||||
For more details, check out [the Stable Diffusion notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
|
||||
and have a look into the [release notes](https://github.com/huggingface/diffusers/releases/tag/v0.2.0).
|
||||
You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked.
|
||||
Please have a look at [Reusing seeds for deterministic generation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/reusing_seeds).
|
||||
|
||||
## Fine-Tuning Stable Diffusion
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@
|
||||
title: "Text-Guided Image-Inpainting"
|
||||
- local: using-diffusers/depth2img
|
||||
title: "Text-Guided Depth-to-Image"
|
||||
- local: using-diffusers/reusing_seeds
|
||||
title: "Reusing seeds for deterministic generation"
|
||||
- local: using-diffusers/custom_pipeline_examples
|
||||
title: "Community Pipelines"
|
||||
- local: using-diffusers/contribute_pipeline
|
||||
|
||||
73
docs/source/using-diffusers/reusing_seeds.mdx
Normal file
73
docs/source/using-diffusers/reusing_seeds.mdx
Normal file
@@ -0,0 +1,73 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Re-using seeds for fast prompt engineering
|
||||
|
||||
A common use case when generating images is to generate a batch of images, select one image and improve it with a better, more detailed prompt in a second run.
|
||||
To do this, one needs to make each generated image of the batch deterministic.
|
||||
Images are generated by denoising gaussian random noise which can be instantiated by passing a [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator).
|
||||
|
||||
Now, for batched generation, we need to make sure that every single generated image in the batch is tied exactly to one seed. In 🧨 Diffusers, this can be achieved by not passing one `generator`, but a list
|
||||
of `generators` to the pipeline.
|
||||
|
||||
Let's go through an example using [`runwayml/stable-diffusion-v1-5`](runwayml/stable-diffusion-v1-5).
|
||||
We want to generate several versions of the prompt:
|
||||
|
||||
```py
|
||||
prompt = "Labrador in the style of Vermeer"
|
||||
```
|
||||
|
||||
Let's load the pipeline
|
||||
|
||||
```python
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
|
||||
>>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
```
|
||||
|
||||
Now, let's define 4 different generators, since we would like to reproduce a certain image. We'll use seeds `0` to `3` to create our generators.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
|
||||
>>> generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
|
||||
```
|
||||
|
||||
Let's generate 4 images:
|
||||
|
||||
```python
|
||||
>>> images = pipe(prompt, generator=generator, num_images_per_prompt=4).images
|
||||
>>> images
|
||||
```
|
||||
|
||||

|
||||
|
||||
Ok, the last images has some double eyes, but the first image looks good!
|
||||
Let's try to make the prompt a bit better **while keeping the first seed**
|
||||
so that the images are similar to the first image.
|
||||
|
||||
```python
|
||||
prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
|
||||
generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
|
||||
```
|
||||
|
||||
We create 4 generators with seed `0`, which is the first seed we used before.
|
||||
|
||||
Let's run the pipeline again.
|
||||
|
||||
```python
|
||||
>>> images = pipe(prompt, generator=generator).images
|
||||
>>> images
|
||||
```
|
||||
|
||||

|
||||
@@ -379,12 +379,24 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -405,7 +417,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -440,8 +452,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -396,8 +396,22 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
@@ -410,16 +424,24 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
)
|
||||
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
shape = init_latents.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
noise = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
||||
]
|
||||
noise = torch.cat(noise, dim=0).to(device)
|
||||
else:
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
@@ -438,7 +460,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -478,8 +500,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -45,7 +45,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 100,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
audio_length_in_s: Optional[float] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[AudioPipelineOutput, Tuple]:
|
||||
@@ -57,8 +57,8 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at
|
||||
the expense of slower inference.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
|
||||
The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
|
||||
`sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
|
||||
@@ -94,9 +94,23 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
sample_size = int(sample_size)
|
||||
|
||||
dtype = next(iter(self.unet.parameters())).dtype
|
||||
audio = torch.randn(
|
||||
(batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype
|
||||
)
|
||||
shape = (batch_size, self.unet.in_channels, sample_size)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
audio = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
audio = torch.cat(audio, dim=0).to(self.device)
|
||||
else:
|
||||
audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -40,7 +40,7 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
eta: float = 0.0,
|
||||
num_inference_steps: int = 50,
|
||||
use_clipped_model_output: Optional[bool] = None,
|
||||
@@ -52,8 +52,8 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
@@ -74,7 +74,12 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
|
||||
if (
|
||||
generator is not None
|
||||
and isinstance(generator, torch.Generator)
|
||||
and generator.device.type != self.device.type
|
||||
and self.device.type != "mps"
|
||||
):
|
||||
message = (
|
||||
f"The `generator` device is `{generator.device}` and does not match the pipeline "
|
||||
f"device `{self.device}`, so the `generator` will be ignored. "
|
||||
@@ -93,12 +98,23 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
else:
|
||||
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
|
||||
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator, dtype=self.unet.dtype)
|
||||
image = image.to(self.device)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + image_shape[1:]
|
||||
image = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
image = torch.cat(image, dim=0).to(self.device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
|
||||
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
|
||||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -42,7 +42,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
num_inference_steps: int = 1000,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -53,8 +53,8 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
num_inference_steps (`int`, *optional*, defaults to 1000):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
|
||||
@@ -71,7 +71,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 1.0,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -95,8 +95,8 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at
|
||||
the, usually at the expense of lower image quality.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -70,7 +70,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
|
||||
batch_size: Optional[int] = 1,
|
||||
num_inference_steps: Optional[int] = 100,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -89,8 +89,8 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -43,7 +43,7 @@ class LDMPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
eta: float = 0.0,
|
||||
num_inference_steps: int = 50,
|
||||
output_type: Optional[str] = "pil",
|
||||
@@ -55,8 +55,8 @@ class LDMPipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
Number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
|
||||
@@ -291,12 +291,24 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -321,7 +333,14 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
if isinstance(generator, list):
|
||||
masked_image_latents = [
|
||||
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
||||
else:
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
masked_image_latents = 0.18215 * masked_image_latents
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
@@ -390,7 +409,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -433,8 +452,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -45,7 +45,7 @@ class PNDMPipeline(DiffusionPipeline):
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
|
||||
@@ -88,7 +88,7 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
eta: float = 0.0,
|
||||
jump_length: int = 10,
|
||||
jump_n_sample: int = 10,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -112,8 +112,8 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
The number of times we will make forward time jump for a given chosen time sample. Take a look at
|
||||
Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -135,19 +135,34 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
mask_image = _preprocess_mask(mask_image)
|
||||
mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)
|
||||
|
||||
batch_size = original_image.shape[0]
|
||||
|
||||
# sample gaussian noise to begin the loop
|
||||
image = torch.randn(
|
||||
original_image.shape,
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
image = image.to(device=self.device, dtype=self.unet.dtype)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
||||
image_shape = original_image.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + image_shape[1:]
|
||||
image = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
image = torch.cat(image, dim=0).to(self.device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
|
||||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
|
||||
self.scheduler.eta = eta
|
||||
|
||||
t_last = self.scheduler.timesteps[0] + 1
|
||||
generator = generator[0] if isinstance(generator, list) else generator
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
if t < t_last:
|
||||
# predict the noise residual
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -41,7 +41,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 2000,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -51,8 +51,8 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -435,8 +435,22 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
batch_size = image.shape[0]
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
@@ -458,7 +472,16 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
|
||||
# add noise to latents using the timestep
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
shape = init_latents.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
noise = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
||||
]
|
||||
noise = torch.cat(noise, dim=0).to(device)
|
||||
else:
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
|
||||
# get latents
|
||||
clean_latents = init_latents
|
||||
@@ -479,7 +502,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
source_guidance_scale: Optional[float] = 1,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -519,8 +542,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -309,13 +309,10 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`jnp.array`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
generation. Can be used to tweak the same generation with different prompts. tensor will ge generated
|
||||
by sampling using the supplied random `generator`.
|
||||
jit (`bool`, defaults to `False`):
|
||||
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
|
||||
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
|
||||
|
||||
@@ -378,12 +378,24 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -404,7 +416,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -439,8 +451,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -345,8 +345,22 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
@@ -359,16 +373,24 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
)
|
||||
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
shape = init_latents.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
noise = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
||||
]
|
||||
noise = torch.cat(noise, dim=0).to(device)
|
||||
else:
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
@@ -429,7 +451,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -468,8 +490,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -258,12 +258,24 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -283,7 +295,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -318,8 +330,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -401,8 +401,22 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
@@ -415,16 +429,24 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
)
|
||||
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
shape = init_latents.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
noise = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
||||
]
|
||||
noise = torch.cat(noise, dim=0).to(device)
|
||||
else:
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
@@ -443,7 +465,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -483,8 +505,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -463,12 +463,24 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -492,7 +504,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
if isinstance(generator, list):
|
||||
masked_image_latents = [
|
||||
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
||||
else:
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
masked_image_latents = 0.18215 * masked_image_latents
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
@@ -535,7 +554,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -578,8 +597,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -432,7 +432,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
add_predicted_noise: Optional[bool] = False,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -479,8 +479,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -332,7 +332,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -367,8 +367,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -338,7 +338,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -371,8 +371,8 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -422,12 +422,24 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -490,7 +502,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -530,8 +542,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -49,7 +49,7 @@ class KarrasVePipeline(DiffusionPipeline):
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -59,8 +59,8 @@ class KarrasVePipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
|
||||
@@ -91,7 +91,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -126,8 +126,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
@@ -207,7 +207,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -242,8 +242,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
@@ -320,7 +320,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -355,8 +355,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -376,12 +376,24 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -416,7 +428,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -452,8 +464,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -241,12 +241,24 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -267,7 +279,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -303,8 +315,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -292,12 +292,24 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -318,7 +330,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -354,8 +366,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -173,7 +173,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
|
||||
guidance_scale: float = 5.0,
|
||||
truncation_rate: float = 1.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -202,8 +202,8 @@ class VQDiffusionPipeline(DiffusionPipeline):
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor` of shape (batch), *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices.
|
||||
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will
|
||||
|
||||
@@ -44,6 +44,11 @@ class PipelineTesterMixin:
|
||||
test_cpu_offload = True
|
||||
test_xformers_attention = True
|
||||
|
||||
def get_generator(self, seed):
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device).manual_seed(seed)
|
||||
return generator
|
||||
|
||||
@property
|
||||
def pipeline_class(self) -> Union[Callable, DiffusionPipeline]:
|
||||
raise NotImplementedError(
|
||||
@@ -176,6 +181,64 @@ class PipelineTesterMixin:
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]:
|
||||
# RePaint can hardly be made deterministic since the scheduler is currently always
|
||||
# indeterministic
|
||||
# CycleDiffusion is also slighly undeterministic
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batch_size = 3
|
||||
for name, value in inputs.items():
|
||||
if name in ALLOWED_REQUIRED_ARGS:
|
||||
# prompt is string
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_inputs[name][-1] = 2000 * "very long"
|
||||
# or else we have images
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
elif name == "batch_size":
|
||||
batched_inputs[name] = batch_size
|
||||
elif name == "generator":
|
||||
batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)]
|
||||
else:
|
||||
batched_inputs[name] = value
|
||||
|
||||
batched_inputs["num_inference_steps"] = inputs["num_inference_steps"]
|
||||
|
||||
if self.pipeline_class.__name__ != "DanceDiffusionPipeline":
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
output_batch = pipe(**batched_inputs)
|
||||
assert output_batch[0].shape[0] == batch_size
|
||||
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
output = pipe(**inputs)
|
||||
|
||||
if torch_device != "mps":
|
||||
# TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
|
||||
# make sure that batched and non-batched is identical
|
||||
assert np.abs(output_batch[0][0] - output[0][0]).max() < 1e-4
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
|
||||
def test_dict_tuple_outputs_equivalent(self):
|
||||
if torch_device == "mps" and self.pipeline_class in (
|
||||
DanceDiffusionPipeline,
|
||||
|
||||
Reference in New Issue
Block a user