mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
import gc
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import uuid
|
|
|
|
import torch
|
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class Utils:
    """Save generated media under the system temp directory and build URLs for it.

    Images are written to ``<tempdir>/images``; a ``<tempdir>/videos`` directory is
    also created. Returned URLs have the form
    ``http://<host>:<port>/images/<filename>``. Whether an HTTP route actually serves
    that path is decided outside this class — NOTE(review): confirm with the server.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8500):
        """Record the base URL and make sure the output directories exist.

        Args:
            host: Host placed in returned URLs. NOTE(review): the default ``0.0.0.0``
                is a bind-all address and may not be reachable by clients; callers
                probably want to pass a public host name.
            port: Port placed in returned URLs.
        """
        self.service_url = f"http://{host}:{port}"

        temp_root = tempfile.gettempdir()
        self.image_dir = os.path.join(temp_root, "images")
        # exist_ok=True replaces the racy exists()/makedirs() pair: if another
        # process creates the directory in between, makedirs would otherwise raise.
        os.makedirs(self.image_dir, exist_ok=True)

        self.video_dir = os.path.join(temp_root, "videos")
        os.makedirs(self.video_dir, exist_ok=True)

    def save_image(self, image) -> str:
        """Write ``image`` as a PNG into ``self.image_dir`` and return its URL.

        Args:
            image: Either an object with a ``save(path, format=..., optimize=...)``
                method (e.g. a PIL image) or a ``torch.Tensor``. Objects exposing
                ``.to`` are first moved to CPU on a best-effort basis. Tensors have a
                leading dimension of size 1 squeezed away, are clamped to [0, 1], and
                are converted with ``torchvision.transforms.ToPILImage``.

        Returns:
            ``f"{self.service_url}/images/{filename}"`` where ``filename`` is
            ``"img"`` + 8 random hex characters + ``".png"``.
        """
        if hasattr(image, "to"):
            try:
                image = image.to("cpu")
            except Exception:
                # Best effort: not every object with a ``.to`` accepts a device string.
                # Keep the original object, but leave a trace instead of hiding it.
                logger.debug("Could not move image to CPU; saving it as-is", exc_info=True)

        if isinstance(image, torch.Tensor):
            # Imported lazily: only tensor inputs need torchvision.
            from torchvision import transforms

            # squeeze(0) drops a leading batch dimension of size 1 (no-op otherwise).
            image = transforms.ToPILImage()(image.squeeze(0).clamp(0, 1))

        # Same 8 hex chars as str(uuid4()).split("-")[0].
        # NOTE(review): only 32 random bits; a collision would silently overwrite
        # an earlier image in the shared directory.
        filename = f"img{uuid.uuid4().hex[:8]}.png"
        image_path = os.path.join(self.image_dir, filename)
        logger.info("Saving image to %s", image_path)
        image.save(image_path, format="PNG", optimize=True)

        # Drops only this function's reference; a caller holding the image keeps it alive.
        del image
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # URLs always use "/"; os.path.join would insert "\" on Windows.
        return f"{self.service_url}/images/{filename}"
|