mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Flax: add shape argument to set_timesteps (#690)
* Flax: add shape argument to set_timesteps * style
This commit is contained in:
@@ -156,7 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
|
||||
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.variance_type = variance_type
|
||||
|
||||
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState:
|
||||
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
||||
@@ -99,7 +99,9 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
):
|
||||
self.state = KarrasVeSchedulerState.create()
|
||||
|
||||
def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState:
|
||||
def set_timesteps(
|
||||
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
|
||||
) -> KarrasVeSchedulerState:
|
||||
"""
|
||||
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
||||
@@ -111,7 +111,9 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return integrated_coeff
|
||||
|
||||
def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState:
|
||||
def set_timesteps(
|
||||
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
|
||||
) -> LMSDiscreteSchedulerState:
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
||||
@@ -156,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
def create_state(self):
|
||||
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
|
||||
|
||||
def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
|
||||
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
|
||||
|
||||
def set_timesteps(
|
||||
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None
|
||||
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
|
||||
) -> ScoreSdeVeSchedulerState:
|
||||
"""
|
||||
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Reference in New Issue
Block a user