mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Type hint] Karras VE pipeline (#288)
* [Type hint] Karras VE pipeline * Apply suggestions from code review Co-authored-by: Anton Lozhkov <anton@huggingface.co> Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
7e1b202d5e
commit
06bc1daf6c
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -21,13 +22,20 @@ class KarrasVePipeline(DiffusionPipeline):
|
||||
unet: UNet2DModel
|
||||
scheduler: KarrasVeScheduler
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs):
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
**kwargs,
|
||||
):
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
|
||||
Reference in New Issue
Block a user