
[Outputs] Improve syntax (#423)

* [Outputs] Improve syntax

* improve more

* fix docstring return

* correct all

* uP

Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Patrick von Platen
2022-09-08 16:46:38 +02:00
committed by GitHub
parent 1a79969d23
commit f6fb3282b1
26 changed files with 162 additions and 76 deletions


@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Diffusion Pipeline
# Pipelines
The [`DiffusionPipeline`] is the easiest way to load any pretrained diffusion pipeline from the [Hub](https://huggingface.co/models?library=diffusers) and to use it in inference.
@@ -28,6 +28,12 @@ pipeline and pass them into the `__init__` function of the pipeline, *e.g.* [`~S
Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrained`].
## DiffusionPipeline
[[autodoc]] DiffusionPipeline
- from_pretrained
- save_pretrained
## ImagePipelineOutput
By default diffusion pipelines return an object of class
[[autodoc]] pipeline_utils.ImagePipelineOutput
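
As a quick, hedged illustration of the loading and saving API documented above (the checkpoint id is the standard example from the diffusers docs; everything else follows the `DiffusionPipeline` interface described here):

```python
from diffusers import DiffusionPipeline

# Download a full pipeline (models, scheduler, configs) from the Hub.
pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32")

# Calling the pipeline returns an ImagePipelineOutput by default;
# its `images` field holds the generated PIL images.
output = pipeline()
output.images[0].save("sample.png")

# The whole pipeline can be written back to disk and reloaded later.
pipeline.save_pretrained("./ddpm-cifar10-32-local")
pipeline = DiffusionPipeline.from_pretrained("./ddpm-cifar10-32-local")
```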


@@ -17,7 +17,6 @@ The original codebase of this paper can be found [here](https://github.com/ermon
| [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py) | *Unconditional Image Generation* | - |
## API
[[autodoc]] pipelines.ddim.pipeline_ddim.DDIMPipeline
## DDIMPipeline
[[autodoc]] DDIMPipeline
- __call__


@@ -19,7 +19,6 @@ The original codebase of this paper can be found [here](https://github.com/hojon
| [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm/pipeline_ddpm.py) | *Unconditional Image Generation* | - |
## API
[[autodoc]] pipelines.ddpm.pipeline_ddpm.DDPMPipeline
# DDPMPipeline
[[autodoc]] DDPMPipeline
- __call__


@@ -25,7 +25,6 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
## Examples:
## API
## LDMTextToImagePipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
- __call__


@@ -24,7 +24,6 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
## Examples:
## API
[[autodoc]] pipelines.latent_diffusion_uncond.pipeline_latent_diffusion_uncond.LDMPipeline
## LDMPipeline
[[autodoc]] LDMPipeline
- __call__


@@ -17,8 +17,7 @@ The original codebase can be found [here](https://github.com/luping-liu/PNDM).
| [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm/pipeline_pndm.py) | *Unconditional Image Generation* | - |
## API
## PNDMPipeline
[[autodoc]] pipelines.pndm.pipeline_pndm.PNDMPipeline
- __call__


@@ -18,8 +18,7 @@ This pipeline implements the Variance Expanding (VE) variant of the method.
|---|---|:---:|
| [pipeline_score_sde_ve.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py) | *Unconditional Image Generation* | - |
## API
[[autodoc]] pipelines.score_sde_ve.pipeline_score_sde_ve.ScoreSdeVePipeline
## ScoreSdeVePipeline
[[autodoc]] ScoreSdeVePipeline
- __call__


@@ -8,29 +8,31 @@ For more details about how Stable Diffusion works and how it differs from the ba
*Tips*:
- To tweak your prompts on a specific result you liked, you can generate your own latents, as demonstrated in the following notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb)
- TODO: some interesting Tips
*Overview*:
| Pipeline | Tasks | Colab | Demo
|---|---|:---:|:---:|
| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) | [🤗 Stable Diffusion](https://huggingface.co/spaces/stabilityai/stable-diffusion)
| [pipeline_stable_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) | [🤗 Diffuse the Rest](https://huggingface.co/spaces/huggingface/diffuse-the-rest)
| [pipeline_stable_diffusion_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | **Experimental** *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) | Coming soon
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
## StableDiffusionPipeline
[[autodoc]] StableDiffusionPipeline
- __init__
- __call__
- enable_attention_slicing
- disable_attention_slicing
- __call__
- enable_attention_slicing
- disable_attention_slicing
## StableDiffusionImg2ImgPipeline
[[autodoc]] StableDiffusionImg2ImgPipeline
- __init__
- __call__
- enable_attention_slicing
- disable_attention_slicing
- __call__
- enable_attention_slicing
- disable_attention_slicing
## StableDiffusionInpaintPipeline
[[autodoc]] StableDiffusionInpaintPipeline
- __init__
- __call__
- enable_attention_slicing
- disable_attention_slicing
- __call__
- enable_attention_slicing
- disable_attention_slicing
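
Tying the tips and the method lists above together, here is a minimal sketch of reusing hand-made latents with attention slicing enabled. The checkpoint id, resolution, and latent shape are assumptions for illustration, and the checkpoint may require authentication:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()  # lower peak memory at a small speed cost

# Build the initial latents yourself so the same seed can be replayed with different prompts.
generator = torch.Generator(device="cuda").manual_seed(42)
latents = torch.randn(
    (1, 4, 64, 64),  # (batch, latent channels, height // 8, width // 8) for 512x512 output
    generator=generator,
    device="cuda",
)

result = pipe("a photograph of an astronaut riding a horse", latents=latents)
result.images[0].save("astronaut.png")
```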


@@ -18,7 +18,6 @@ This pipeline implements the Stochastic sampling tailored to the Variance-Expand
| [pipeline_stochastic_karras_ve.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py) | *Unconditional Image Generation* | - |
## API
[[autodoc]] pipelines.stochastic_karras_ve.pipeline_stochastic_karras_ve.KarrasVePipeline
## KarrasVePipeline
[[autodoc]] KarrasVePipeline
- __call__


@@ -52,21 +52,26 @@ class DDIMPipeline(DiffusionPipeline):
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
batch_size (:obj:`int`, *optional*, defaults to 1):
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (:obj:`torch.Generator`, *optional*):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
eta (:obj:`float`, *optional*, defaults to 0.0):
eta (`float`, *optional*, defaults to 0.0):
The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
num_inference_steps (:obj:`int`, *optional*, defaults to 50):
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:


@@ -50,16 +50,21 @@ class DDPMPipeline(DiffusionPipeline):
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
batch_size (:obj:`int`, *optional*, defaults to 1):
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (:obj:`torch.Generator`, *optional*):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
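
To make the docstring above concrete, a short sketch of both return modes of `DDPMPipeline.__call__` (checkpoint id as in the diffusers examples):

```python
import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
generator = torch.Generator().manual_seed(0)  # makes sampling reproducible

# Default: an ImagePipelineOutput whose `images` field is a list of PIL images.
output = pipe(batch_size=2, generator=generator)
output.images[0].save("ddpm_sample.png")

# With return_dict=False the call returns a plain tuple; its first element
# is the list of generated images.
(images,) = pipe(batch_size=2, return_dict=False)
```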


@@ -88,6 +88,11 @@ class LDMTextToImagePipeline(DiffusionPipeline):
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")


@@ -54,6 +54,11 @@ class LDMPipeline(DiffusionPipeline):
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:


@@ -30,7 +30,7 @@ class PNDMPipeline(DiffusionPipeline):
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
unet (:obj:`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image.
"""
@@ -55,20 +55,22 @@ class PNDMPipeline(DiffusionPipeline):
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
batch_size (:obj:`int`, `optional`, defaults to 1): The number of images to generate.
num_inference_steps (:
obj:`int`, `optional`, defaults to 50): The number of denoising steps. More denoising steps usually
lead to a higher quality image at the expense of slower inference.
generator (:
obj:`torch.Generator`, `optional`): A [torch
batch_size (`int`, `optional`, defaults to 1): The number of images to generate.
num_inference_steps (`int`, `optional`, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator`, `optional`): A [torch
generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (:
obj:`str`, `optional`, defaults to :obj:`"pil"`): The output format of the generate image. Choose
output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose
between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:
obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to return a
return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a
[`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf


@@ -36,16 +36,21 @@ class ScoreSdeVePipeline(DiffusionPipeline):
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
batch_size (:obj:`int`, *optional*, defaults to 1):
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (:obj:`torch.Generator`, *optional*):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:


@@ -142,7 +142,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
plain tuple.
Returns:
`~pipelines.stable_diffusion.StableDiffusionPipelineOutput` if `return_dict` is True, otherwise a tuple.
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
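
A minimal sketch of consuming the two fields of `StableDiffusionPipelineOutput` described above (the checkpoint id is an assumption; `pipe` is a regular `StableDiffusionPipeline`):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
output = pipe(["a sunset over the ocean", "a bowl of fruit"])

# `images` and `nsfw_content_detected` are aligned element-wise: when the
# safety_checker flags an image, the pipeline returns a blank image in its place.
for i, (image, flagged) in enumerate(zip(output.images, output.nsfw_content_detected)):
    if flagged:
        print(f"Prompt {i} was flagged as potentially NSFW.")
    else:
        image.save(f"result_{i}.png")
```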


@@ -152,7 +152,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
plain tuple.
Returns:
`~pipelines.stable_diffusion.StableDiffusionPipelineOutput` if `return_dict` is True, otherwise a tuple.
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.


@@ -170,7 +170,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
plain tuple.
Returns:
`~pipelines.stable_diffusion.StableDiffusionPipelineOutput` if `return_dict` is True, otherwise a tuple.
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.


@@ -45,19 +45,24 @@ class KarrasVePipeline(DiffusionPipeline):
) -> Union[Tuple, ImagePipelineOutput]:
r"""
Args:
batch_size (:obj:`int`, *optional*, defaults to 1):
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (:obj:`torch.Generator`, *optional*):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
num_inference_steps (:obj:`int`, *optional*, defaults to 50):
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")


@@ -59,6 +59,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
Args:
@@ -171,7 +176,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
`SchedulerOutput`: updated sample in the diffusion chain.
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
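
The new `ConfigMixin` paragraph and the `step()` return description can be illustrated with a small sketch (shapes and values are assumptions; the API is as documented in this commit):

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)

# Every __init__ argument is stored on the config and can be read back.
print(scheduler.config.num_train_timesteps)  # 1000

# ConfigMixin also handles saving and reloading the scheduler configuration.
scheduler.save_config("./ddim-scheduler")
scheduler = DDIMScheduler.from_config("./ddim-scheduler")

# step() returns a SchedulerOutput by default; with return_dict=False it is a tuple.
scheduler.set_timesteps(50)
sample = torch.randn(1, 3, 32, 32)        # current noisy sample x_t
model_output = torch.randn(1, 3, 32, 32)  # stand-in for the model's noise prediction
t = scheduler.timesteps[0]

out = scheduler.step(model_output, t, sample)
prev_sample = out.prev_sample

(prev_sample,) = scheduler.step(model_output, t, sample, return_dict=False)
```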


@@ -58,6 +58,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
Langevin dynamics sampling.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
Args:
@@ -189,7 +194,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
`SchedulerOutput`: updated sample in the diffusion chain.
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
t = timestep


@@ -50,6 +50,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
differential equations." https://arxiv.org/abs/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
@@ -147,8 +152,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
Returns:
[`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
[`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""


@@ -29,6 +29,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
@@ -143,7 +148,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
sigma = self.sigmas[timestep]


@@ -58,6 +58,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
namely Runge-Kutta method and a linear multi-step method.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
Args:
@@ -186,7 +191,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
`SchedulerOutput`: updated sample in the diffusion chain.
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
@@ -213,7 +220,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
@@ -267,7 +275,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:


@@ -49,6 +49,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
For more information, see the original paper: https://arxiv.org/abs/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
snr (`float`):
coefficient weighting the step from the model_output sample (from the network) to the random noise.
@@ -182,7 +187,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if "seed" in kwargs and kwargs["seed"] is not None:
@@ -241,7 +247,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if "seed" in kwargs and kwargs["seed"] is not None:


@@ -27,6 +27,11 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""
The variance preserving stochastic differential equation (SDE) scheduler.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more information, see the original paper: https://arxiv.org/abs/2011.13456
UNDER CONSTRUCTION