1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[ONNX] Support Euler schedulers (#1328)

This commit is contained in:
Anton Lozhkov
2022-11-17 16:37:35 +01:00
committed by GitHub
parent 632dacea2f
commit e05ca84f41
3 changed files with 12 additions and 6 deletions

View File

@@ -261,8 +261,10 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = np.array(latents)
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
# call the callback, if provided
if callback is not None and i % callback_steps == 0:

View File

@@ -401,8 +401,10 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = latents.numpy()
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
# call the callback, if provided
if callback is not None and i % callback_steps == 0:

View File

@@ -424,8 +424,10 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = latents.numpy()
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
# call the callback, if provided
if callback is not None and i % callback_steps == 0: