From c53a850604c4de6ca385a408854b022bbafadce2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 17 Dec 2022 11:13:16 +0100 Subject: [PATCH] [Batched Generators] This PR adds generators that are useful to make batched generation fully reproducible (#1718) * [Batched Generators] all batched generators * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * hey * up again * fix tests * Apply suggestions from code review Co-authored-by: Pedro Cuenca * correct tests Co-authored-by: Pedro Cuenca --- README.md | 7 +- docs/source/_toctree.yml | 2 + docs/source/using-diffusers/reusing_seeds.mdx | 73 +++++++++++++++++++ .../alt_diffusion/pipeline_alt_diffusion.py | 26 +++++-- .../pipeline_alt_diffusion_img2img.py | 40 +++++++--- .../pipeline_dance_diffusion.py | 28 +++++-- src/diffusers/pipelines/ddim/pipeline_ddim.py | 36 ++++++--- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 8 +- .../pipeline_latent_diffusion.py | 6 +- ...peline_latent_diffusion_superresolution.py | 8 +- .../pipeline_latent_diffusion_uncond.py | 8 +- .../pipeline_paint_by_example.py | 35 +++++++-- src/diffusers/pipelines/pndm/pipeline_pndm.py | 4 +- .../pipelines/repaint/pipeline_repaint.py | 33 ++++++--- .../score_sde_ve/pipeline_score_sde_ve.py | 8 +- .../pipeline_cycle_diffusion.py | 35 +++++++-- .../pipeline_flax_stable_diffusion.py | 7 +- .../pipeline_stable_diffusion.py | 26 +++++-- .../pipeline_stable_diffusion_depth2img.py | 40 +++++++--- ...peline_stable_diffusion_image_variation.py | 26 +++++-- .../pipeline_stable_diffusion_img2img.py | 40 +++++++--- .../pipeline_stable_diffusion_inpaint.py | 35 +++++++-- ...ipeline_stable_diffusion_inpaint_legacy.py | 6 +- .../pipeline_stable_diffusion_k_diffusion.py | 6 +- .../pipeline_stable_diffusion_upscale.py | 6 +- .../pipeline_stable_diffusion_safe.py | 26 +++++-- .../pipeline_stochastic_karras_ve.py | 8 +- .../pipeline_versatile_diffusion.py | 18 ++--- ...ipeline_versatile_diffusion_dual_guided.py | 26 +++++-- ...ine_versatile_diffusion_image_variation.py | 26 +++++-- ...eline_versatile_diffusion_text_to_image.py | 26 +++++-- .../vq_diffusion/pipeline_vq_diffusion.py | 6 +- tests/test_pipelines_common.py | 63 ++++++++++++++++ 33 files changed, 568 insertions(+), 180 deletions(-) create mode 100644 docs/source/using-diffusers/reusing_seeds.mdx diff --git a/README.md b/README.md index c320d86b9d..6800dbc22a 100644 --- a/README.md +++ b/README.md @@ -302,11 +302,8 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] ### Tweak prompts reusing seeds and latents -You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb). - - -For more details, check out [the Stable Diffusion notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) -and have a look into the [release notes](https://github.com/huggingface/diffusers/releases/tag/v0.2.0). +You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. +Please have a look at [Reusing seeds for deterministic generation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/reusing_seeds). ## Fine-Tuning Stable Diffusion diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index ec578e1721..b981f5f9fa 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -28,6 +28,8 @@ title: "Text-Guided Image-Inpainting" - local: using-diffusers/depth2img title: "Text-Guided Depth-to-Image" + - local: using-diffusers/reusing_seeds + title: "Reusing seeds for deterministic generation" - local: using-diffusers/custom_pipeline_examples title: "Community Pipelines" - local: using-diffusers/contribute_pipeline diff --git a/docs/source/using-diffusers/reusing_seeds.mdx b/docs/source/using-diffusers/reusing_seeds.mdx new file mode 100644 index 0000000000..54238cdf21 --- /dev/null +++ b/docs/source/using-diffusers/reusing_seeds.mdx @@ -0,0 +1,73 @@ + + +# Re-using seeds for fast prompt engineering + +A common use case when generating images is to generate a batch of images, select one image and improve it with a better, more detailed prompt in a second run. +To do this, one needs to make each generated image of the batch deterministic. +Images are generated by denoising gaussian random noise which can be instantiated by passing a [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator). + +Now, for batched generation, we need to make sure that every single generated image in the batch is tied exactly to one seed. In 🧨 Diffusers, this can be achieved by not passing one `generator`, but a list +of `generators` to the pipeline. + +Let's go through an example using [`runwayml/stable-diffusion-v1-5`](runwayml/stable-diffusion-v1-5). +We want to generate several versions of the prompt: + +```py +prompt = "Labrador in the style of Vermeer" +``` + +Let's load the pipeline + +```python +>>> from diffusers import DiffusionPipeline + +>>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +>>> pipe = pipe.to("cuda") +``` + +Now, let's define 4 different generators, since we would like to reproduce a certain image. We'll use seeds `0` to `3` to create our generators. + +```python +>>> import torch + +>>> generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)] +``` + +Let's generate 4 images: + +```python +>>> images = pipe(prompt, generator=generator, num_images_per_prompt=4).images +>>> images +``` + +![img](https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds.jpg) + +Ok, the last images has some double eyes, but the first image looks good! +Let's try to make the prompt a bit better **while keeping the first seed** +so that the images are similar to the first image. + +```python +prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]] +generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)] +``` + +We create 4 generators with seed `0`, which is the first seed we used before. + +Let's run the pipeline again. + +```python +>>> images = pipe(prompt, generator=generator).images +>>> images +``` + +![img](https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds_2.jpg) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index f149e2623f..932c4afc5a 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -379,12 +379,24 @@ class AltDiffusionPipeline(DiffusionPipeline): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -405,7 +417,7 @@ class AltDiffusionPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -440,8 +452,8 @@ class AltDiffusionPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 6d7e57c7ca..61ee352d70 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -396,8 +396,22 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): image = image.to(device=device, dtype=dtype) - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) + + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = 0.18215 * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: @@ -410,16 +424,24 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: - init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents = torch.cat([init_latents], dim=0) - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + rand_device = "cpu" if device.type == "mps" else device + shape = init_latents.shape + if isinstance(generator, list): + shape = (1,) + shape[1:] + noise = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size) + ] + noise = torch.cat(noise, dim=0).to(device) + else: + noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) @@ -438,7 +460,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -478,8 +500,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py index 48d16889a0..903b50cf33 100644 --- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch @@ -45,7 +45,7 @@ class DanceDiffusionPipeline(DiffusionPipeline): self, batch_size: int = 1, num_inference_steps: int = 100, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, audio_length_in_s: Optional[float] = None, return_dict: bool = True, ) -> Union[AudioPipelineOutput, Tuple]: @@ -57,8 +57,8 @@ class DanceDiffusionPipeline(DiffusionPipeline): The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at the expense of slower inference. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.* `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`. @@ -94,9 +94,23 @@ class DanceDiffusionPipeline(DiffusionPipeline): sample_size = int(sample_size) dtype = next(iter(self.unet.parameters())).dtype - audio = torch.randn( - (batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype - ) + shape = (batch_size, self.unet.in_channels, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + rand_device = "cpu" if self.device.type == "mps" else self.device + if isinstance(generator, list): + shape = (1,) + shape[1:] + audio = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype) + for i in range(batch_size) + ] + audio = torch.cat(audio, dim=0).to(self.device) + else: + audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device) # set step values self.scheduler.set_timesteps(num_inference_steps, device=audio.device) diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index f39f7f8d1f..5de093db69 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch @@ -40,7 +40,7 @@ class DDIMPipeline(DiffusionPipeline): def __call__( self, batch_size: int = 1, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, eta: float = 0.0, num_inference_steps: int = 50, use_clipped_model_output: Optional[bool] = None, @@ -52,8 +52,8 @@ class DDIMPipeline(DiffusionPipeline): batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. eta (`float`, *optional*, defaults to 0.0): The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). num_inference_steps (`int`, *optional*, defaults to 50): @@ -74,7 +74,12 @@ class DDIMPipeline(DiffusionPipeline): generated images. """ - if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": + if ( + generator is not None + and isinstance(generator, torch.Generator) + and generator.device.type != self.device.type + and self.device.type != "mps" + ): message = ( f"The `generator` device is `{generator.device}` and does not match the pipeline " f"device `{self.device}`, so the `generator` will be ignored. " @@ -93,12 +98,23 @@ class DDIMPipeline(DiffusionPipeline): else: image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) - if self.device.type == "mps": - # randn does not work reproducibly on mps - image = torch.randn(image_shape, generator=generator, dtype=self.unet.dtype) - image = image.to(self.device) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + rand_device = "cpu" if self.device.type == "mps" else self.device + if isinstance(generator, list): + shape = (1,) + image_shape[1:] + image = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype) + for i in range(batch_size) + ] + image = torch.cat(image, dim=0).to(self.device) else: - image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) + image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype) + image = image.to(self.device) # set step values self.scheduler.set_timesteps(num_inference_steps) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 191dbdfe43..5af29b0b90 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch @@ -42,7 +42,7 @@ class DDPMPipeline(DiffusionPipeline): def __call__( self, batch_size: int = 1, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, num_inference_steps: int = 1000, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -53,8 +53,8 @@ class DDPMPipeline(DiffusionPipeline): batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. num_inference_steps (`int`, *optional*, defaults to 1000): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 5fa05d8a8b..83f8d355a7 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -71,7 +71,7 @@ class LDMTextToImagePipeline(DiffusionPipeline): num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 1.0, eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -95,8 +95,8 @@ class LDMTextToImagePipeline(DiffusionPipeline): 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at the, usually at the expense of lower image quality. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py index e417424625..6ccb025c41 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -1,5 +1,5 @@ import inspect -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -70,7 +70,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline): batch_size: Optional[int] = 1, num_inference_steps: Optional[int] = 100, eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, @@ -89,8 +89,8 @@ class LDMSuperResolutionPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 5345c4e562..19380d3624 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch @@ -43,7 +43,7 @@ class LDMPipeline(DiffusionPipeline): def __call__( self, batch_size: int = 1, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, eta: float = 0.0, num_inference_steps: int = 50, output_type: Optional[str] = "pil", @@ -55,8 +55,8 @@ class LDMPipeline(DiffusionPipeline): batch_size (`int`, *optional*, defaults to 1): Number of images to generate. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index e37f771455..332165e1fa 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -291,12 +291,24 @@ class PaintByExamplePipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -321,7 +333,14 @@ class PaintByExamplePipeline(DiffusionPipeline): masked_image = masked_image.to(device=device, dtype=dtype) # encode the mask image into latents space so we can concatenate it to the latents - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) + if isinstance(generator, list): + masked_image_latents = [ + self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) masked_image_latents = 0.18215 * masked_image_latents # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method @@ -390,7 +409,7 @@ class PaintByExamplePipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -433,8 +452,8 @@ class PaintByExamplePipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index ef7062dea1..020dd7c4ee 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch @@ -45,7 +45,7 @@ class PNDMPipeline(DiffusionPipeline): self, batch_size: int = 1, num_inference_steps: int = 50, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py index d100001f63..a316e5bc81 100644 --- a/src/diffusers/pipelines/repaint/pipeline_repaint.py +++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py @@ -88,7 +88,7 @@ class RePaintPipeline(DiffusionPipeline): eta: float = 0.0, jump_length: int = 10, jump_n_sample: int = 10, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, @@ -112,8 +112,8 @@ class RePaintPipeline(DiffusionPipeline): The number of times we will make forward time jump for a given chosen time sample. Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -135,19 +135,34 @@ class RePaintPipeline(DiffusionPipeline): mask_image = _preprocess_mask(mask_image) mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype) + batch_size = original_image.shape[0] + # sample gaussian noise to begin the loop - image = torch.randn( - original_image.shape, - generator=generator, - device=self.device, - ) - image = image.to(device=self.device, dtype=self.unet.dtype) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + rand_device = "cpu" if self.device.type == "mps" else self.device + image_shape = original_image.shape + if isinstance(generator, list): + shape = (1,) + image_shape[1:] + image = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype) + for i in range(batch_size) + ] + image = torch.cat(image, dim=0).to(self.device) + else: + image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype) + image = image.to(self.device) # set step values self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device) self.scheduler.eta = eta t_last = self.scheduler.timesteps[0] + 1 + generator = generator[0] if isinstance(generator, list) else generator for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): if t < t_last: # predict the noise residual diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 7eb6a5d3cb..8dc7001a4a 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch @@ -41,7 +41,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): self, batch_size: int = 1, num_inference_steps: int = 2000, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, @@ -51,8 +51,8 @@ class ScoreSdeVePipeline(DiffusionPipeline): batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index c51ebf2611..db1155c665 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -435,8 +435,22 @@ class CycleDiffusionPipeline(DiffusionPipeline): def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): image = image.to(device=device, dtype=dtype) - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) + + batch_size = image.shape[0] + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = 0.18215 * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: @@ -458,7 +472,16 @@ class CycleDiffusionPipeline(DiffusionPipeline): init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) # add noise to latents using the timestep - noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + rand_device = "cpu" if device.type == "mps" else device + shape = init_latents.shape + if isinstance(generator, list): + shape = (1,) + shape[1:] + noise = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size) + ] + noise = torch.cat(noise, dim=0).to(device) + else: + noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) # get latents clean_latents = init_latents @@ -479,7 +502,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): source_guidance_scale: Optional[float] = 1, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.1, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -519,8 +542,8 @@ class CycleDiffusionPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 912a4381d0..e26d71ffda 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -309,13 +309,10 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. latents (`jnp.array`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. + generation. Can be used to tweak the same generation with different prompts. tensor will ge generated + by sampling using the supplied random `generator`. jit (`bool`, defaults to `False`): Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 4681ffdc82..37a041bd33 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -378,12 +378,24 @@ class StableDiffusionPipeline(DiffusionPipeline): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -404,7 +416,7 @@ class StableDiffusionPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -439,8 +451,8 @@ class StableDiffusionPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 6bff3e9494..dfd3d23393 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -345,8 +345,22 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): image = image.to(device=device, dtype=dtype) - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) + + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = 0.18215 * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: @@ -359,16 +373,24 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: - init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents = torch.cat([init_latents], dim=0) - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + rand_device = "cpu" if device.type == "mps" else device + shape = init_latents.shape + if isinstance(generator, list): + shape = (1,) + shape[1:] + noise = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size) + ] + noise = torch.cat(noise, dim=0).to(device) + else: + noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) @@ -429,7 +451,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -468,8 +490,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 1b6c8475d8..3bc0a8cfa9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -258,12 +258,24 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -283,7 +295,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -318,8 +330,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index ffb8f7b602..d337b59c93 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -401,8 +401,22 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): image = image.to(device=device, dtype=dtype) - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) + + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = 0.18215 * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: @@ -415,16 +429,24 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: - init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents = torch.cat([init_latents], dim=0) - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + rand_device = "cpu" if device.type == "mps" else device + shape = init_latents.shape + if isinstance(generator, list): + shape = (1,) + shape[1:] + noise = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size) + ] + noise = torch.cat(noise, dim=0).to(device) + else: + noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) @@ -443,7 +465,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -483,8 +505,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a909c10167..deec66a976 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -463,12 +463,24 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -492,7 +504,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): masked_image = masked_image.to(device=device, dtype=dtype) # encode the mask image into latents space so we can concatenate it to the latents - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) + if isinstance(generator, list): + masked_image_latents = [ + self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) masked_image_latents = 0.18215 * masked_image_latents # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method @@ -535,7 +554,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -578,8 +597,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 53ba2580a8..797c32c198 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -432,7 +432,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): num_images_per_prompt: Optional[int] = 1, add_predicted_noise: Optional[bool] = False, eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -479,8 +479,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 273688aeb7..1bb0cb051a 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -332,7 +332,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -367,8 +367,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index ce68af1375..cfb05a16b0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -338,7 +338,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -371,8 +371,8 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 921befcdf3..d5b163ba46 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -422,12 +422,24 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -490,7 +502,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -530,8 +542,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py index 739de8ebe6..8da1faf9c6 100644 --- a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch @@ -49,7 +49,7 @@ class KarrasVePipeline(DiffusionPipeline): self, batch_size: int = 1, num_inference_steps: int = 50, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, @@ -59,8 +59,8 @@ class KarrasVePipeline(DiffusionPipeline): batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index e9ef505040..d0202922da 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -91,7 +91,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -126,8 +126,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -207,7 +207,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -242,8 +242,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -320,7 +320,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -355,8 +355,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 89b333c9c4..b7c22494df 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -376,12 +376,24 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -416,7 +428,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -452,8 +464,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 7419d2f37d..122fd7a9f7 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -241,12 +241,24 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -267,7 +279,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -303,8 +315,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index 3920955f67..e12eabded7 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -292,12 +292,24 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -318,7 +330,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -354,8 +366,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index 333599d7ec..a536f91d9e 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -173,7 +173,7 @@ class VQDiffusionPipeline(DiffusionPipeline): guidance_scale: float = 5.0, truncation_rate: float = 1.0, num_images_per_prompt: int = 1, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -202,8 +202,8 @@ class VQDiffusionPipeline(DiffusionPipeline): num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor` of shape (batch), *optional*): Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 6352b4b605..5d5c69c5c7 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -44,6 +44,11 @@ class PipelineTesterMixin: test_cpu_offload = True test_xformers_attention = True + def get_generator(self, seed): + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device).manual_seed(seed) + return generator + @property def pipeline_class(self) -> Union[Callable, DiffusionPipeline]: raise NotImplementedError( @@ -176,6 +181,64 @@ class PipelineTesterMixin: logger.setLevel(level=diffusers.logging.WARNING) + def test_inference_batch_single_identical(self): + if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]: + # RePaint can hardly be made deterministic since the scheduler is currently always + # indeterministic + # CycleDiffusion is also slighly undeterministic + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # batchify inputs + batched_inputs = {} + batch_size = 3 + for name, value in inputs.items(): + if name in ALLOWED_REQUIRED_ARGS: + # prompt is string + if name == "prompt": + len_prompt = len(value) + # make unequal batch sizes + batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] + + # make last batch super long + batched_inputs[name][-1] = 2000 * "very long" + # or else we have images + else: + batched_inputs[name] = batch_size * [value] + elif name == "batch_size": + batched_inputs[name] = batch_size + elif name == "generator": + batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)] + else: + batched_inputs[name] = value + + batched_inputs["num_inference_steps"] = inputs["num_inference_steps"] + + if self.pipeline_class.__name__ != "DanceDiffusionPipeline": + batched_inputs["output_type"] = "np" + + output_batch = pipe(**batched_inputs) + assert output_batch[0].shape[0] == batch_size + + inputs["generator"] = self.get_generator(0) + + output = pipe(**inputs) + + if torch_device != "mps": + # TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems + # make sure that batched and non-batched is identical + assert np.abs(output_batch[0][0] - output[0][0]).max() < 1e-4 + logger.setLevel(level=diffusers.logging.WARNING) + def test_dict_tuple_outputs_equivalent(self): if torch_device == "mps" and self.pipeline_class in ( DanceDiffusionPipeline,