diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index bba476b5fe..6b5ce9530d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -227,8 +227,9 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. prior_callback (`Callable`, *optional*): - A function that will be called every `prior_callback_steps` steps during inference. The function will be - called with the following arguments: `prior_callback(step: int, timestep: int, latents: torch.FloatTensor)`. + A function that will be called every `prior_callback_steps` steps during inference. The function will + be called with the following arguments: `prior_callback(step: int, timestep: int, latents: + torch.FloatTensor)`. prior_callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 2e3147b80e..9d9472a906 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -15,7 +15,7 @@ import importlib import os from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Union +from typing import Optional, Union import torch diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index d0bed6ef5f..ccec121d30 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -16,7 +16,7 @@ import math import os from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import flax import jax.numpy as jnp