1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Files
diffusers/examples/community/one_step_unet.py
Patrick von Platen 1c96f82ed9 Update one_step_unet.py
Fix dummy community pipeline
2023-04-09 19:22:18 +01:00

25 lines
721 B
Python
Executable File

#!/usr/bin/env python3
import torch
from diffusers import DiffusionPipeline
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    """Dummy pipeline: one UNet forward pass followed by one scheduler step.

    The numeric result of the model is deliberately cancelled out in
    ``__call__``, which returns a tensor of ones shaped like the scheduler
    output. The pipeline therefore exercises the ``unet`` and ``scheduler``
    call paths while producing a constant output.
    """

    def __init__(self, unet, scheduler):
        """Register the model and scheduler on the pipeline.

        Args:
            unet: Model invoked as ``unet(sample, timestep)``. Its result must
                expose ``.sample`` and the model must expose
                ``config.in_channels`` and ``config.sample_size``.
            scheduler: Object whose ``step(model_output, timestep, sample)``
                result exposes ``.prev_sample``.
        """
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(self):
        """Run a single forward/step pass and return a tensor of ones.

        Returns:
            A tensor created with ``torch.ones_like`` from the scheduler's
            ``prev_sample`` (so it shares its shape, dtype and device). If the
            scheduler output contains NaN or infinite entries, those positions
            come out as NaN because ``x - x`` is NaN for such values.
        """
        config = self.unet.config
        shape = (1, config.in_channels, config.sample_size, config.sample_size)

        # Draw the noise exactly as before, then follow the model's placement
        # so the forward pass does not fail with a device/dtype mismatch after
        # the model has been moved or cast. Models that do not report a
        # `device`/`dtype` leave the tensor untouched (`None` means "keep").
        # NOTE(review): assumes `unet.device` / `unet.dtype`, when present,
        # are a torch device and a floating dtype - confirm for custom models.
        image = torch.randn(shape).to(
            device=getattr(self.unet, "device", None),
            dtype=getattr(self.unet, "dtype", None),
        )

        timestep = 1
        model_output = self.unet(image, timestep).sample
        scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample

        # `x - x` zeroes every finite value, so the sum is all ones while still
        # depending on (and inheriting the layout of) the scheduler output.
        result = scheduler_output - scheduler_output + torch.ones_like(scheduler_output)
        return result