mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Add pred_original_sample to if not return_dict path (#9649)
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -463,7 +463,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
prev_sample = prev_sample + variance
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
return (
|
||||
prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
@@ -394,7 +394,10 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
prev_sample = a_t * sample + b_t * pred_original_sample
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
return (
|
||||
prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
@@ -480,7 +480,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
prev_sample = prev_sample + variance
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
return (
|
||||
prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
@@ -492,7 +492,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
pred_prev_sample = pred_prev_sample + variance
|
||||
|
||||
if not return_dict:
|
||||
return (pred_prev_sample,)
|
||||
return (
|
||||
pred_prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
@@ -500,7 +500,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
pred_prev_sample = pred_prev_sample + variance
|
||||
|
||||
if not return_dict:
|
||||
return (pred_prev_sample,)
|
||||
return (
|
||||
pred_prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
@@ -360,7 +360,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
return (
|
||||
prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
@@ -435,7 +435,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
return (
|
||||
prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return EulerAncestralDiscreteSchedulerOutput(
|
||||
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
||||
|
||||
@@ -677,7 +677,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
return (
|
||||
prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
@@ -507,7 +507,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
return (
|
||||
prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
@@ -320,7 +320,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
|
||||
pred_prev_sample = pred_prev_sample + variance
|
||||
|
||||
if not return_dict:
|
||||
return (pred_prev_sample,)
|
||||
return (
|
||||
pred_prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user