diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index d81d666071..a57f61e070 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -156,7 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): return variance - def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState: + def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 7c7b8d29ab..8f631403ff 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -133,7 +133,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin): self.variance_type = variance_type - def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState: + def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index c320b79e6d..72ff69da03 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -99,7 +99,9 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin): ): self.state = KarrasVeSchedulerState.create() - def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState: + def set_timesteps( + self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple + ) -> KarrasVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 4784e4fafc..51b3dfc72f 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -111,7 +111,9 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): return integrated_coeff - def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState: + def set_timesteps( + self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple + ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 9e2b19f013..871a406f68 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -156,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): def create_state(self): return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) - def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState: + def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index 08fbe14732..3c8834d7b1 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -95,7 +95,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps) def set_timesteps( - self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None + self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None ) -> ScoreSdeVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.