1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Add pred_original_sample to if not return_dict path (#9649)

Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
hlky
2024-10-15 04:56:54 +01:00
committed by GitHub
parent 22ed39f571
commit 1bcd19e4d0
10 changed files with 40 additions and 10 deletions

View File

@@ -463,7 +463,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

View File

@@ -394,7 +394,10 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = a_t * sample + b_t * pred_original_sample
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

View File

@@ -480,7 +480,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

View File

@@ -492,7 +492,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample,)
return (
pred_prev_sample,
pred_original_sample,
)
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

View File

@@ -500,7 +500,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample,)
return (
pred_prev_sample,
pred_original_sample,
)
return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

View File

@@ -360,7 +360,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

View File

@@ -435,7 +435,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return EulerAncestralDiscreteSchedulerOutput(
prev_sample=prev_sample, pred_original_sample=pred_original_sample

View File

@@ -677,7 +677,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

View File

@@ -507,7 +507,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

View File

@@ -320,7 +320,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample,)
return (
pred_prev_sample,
pred_original_sample,
)
return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)