1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Type hint] Karras VE pipeline (#288)

* [Type hint] Karras VE pipeline

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <anton@huggingface.co>

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
Patrick von Platen
2022-08-31 12:50:11 +02:00
committed by GitHub
parent 7e1b202d5e
commit 06bc1daf6c

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
import warnings
from typing import Optional
import torch
@@ -21,13 +22,20 @@ class KarrasVePipeline(DiffusionPipeline):
unet: UNet2DModel
scheduler: KarrasVeScheduler
def __init__(self, unet, scheduler):
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs):
def __call__(
self,
batch_size: int = 1,
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
**kwargs,
):
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(