1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Files
diffusers/examples/server-async/serverasync.py
2025-10-03 11:38:16 +05:30

231 lines
6.8 KiB
Python

import asyncio
import gc
import logging
import os
import random
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from Pipelines import ModelPipelineInitializer
from pydantic import BaseModel
from utils import RequestScopedPipeline, Utils
@dataclass
class ServerConfigModels:
    """Static configuration for the diffusers server.

    ``@dataclass`` derives the positional constructor from the field order
    below, so keep that order stable.
    """

    # Model identifier handed to ModelPipelineInitializer.
    model: str = "stabilityai/stable-diffusion-3.5-medium"
    # Pipeline family handed to ModelPipelineInitializer.
    type_models: str = "t2im"
    # Optional pipeline overrides; every one of them defaults to None.
    constructor_pipeline: Optional[Type] = None
    custom_pipeline: Optional[Type] = None
    components: Optional[Dict[str, Any]] = None
    torch_dtype: Optional[torch.dtype] = None
    # Bind address for uvicorn; also passed to the Utils helper.
    host: str = "0.0.0.0"
    port: int = 8500


# Single shared configuration instance read by the lifespan hook and endpoints.
server_config = ServerConfigModels()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: set up shared state on startup, tear it down on shutdown.

    Startup configures logging and initializes the request/inference counters,
    the asyncio lock guarding them, the ``Utils`` helper, and a background task
    that logs the counters every 5 seconds. Shutdown cancels that task, then
    calls the pipeline's ``stop`` (or ``close``) method in a worker thread if it
    has one.
    """
    logging.basicConfig(level=logging.INFO)
    app.state.logger = logging.getLogger("diffusers-server")

    # Only fill in defaults; do not clobber values the operator exported.
    # NOTE(review): PyTorch appears to read these when CUDA is first initialized.
    # If the pipeline was already moved to the GPU during module import, setting
    # them here may have no effect; consider exporting them before launch.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,expandable_segments:True")
    os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")

    app.state.total_requests = 0
    app.state.active_inferences = 0
    app.state.metrics_lock = asyncio.Lock()
    app.state.metrics_task = None
    app.state.utils_app = Utils(
        host=server_config.host,
        port=server_config.port,
    )

    async def metrics_loop():
        # Snapshot the counters under the lock, log outside it, repeat every 5 s.
        try:
            while True:
                async with app.state.metrics_lock:
                    total = app.state.total_requests
                    active = app.state.active_inferences
                app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
                await asyncio.sleep(5)
        except asyncio.CancelledError:
            app.state.logger.info("Metrics loop cancelled")
            raise

    app.state.metrics_task = asyncio.create_task(metrics_loop())
    try:
        yield
    finally:
        task = app.state.metrics_task
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Read the pipeline from app.state, as the request handlers do, instead of
        # depending on a module-level global. Missing pipeline means nothing to stop.
        pipeline = getattr(app.state, "MODEL_PIPELINE", None)
        try:
            stop_fn = getattr(pipeline, "stop", None) or getattr(pipeline, "close", None)
            if callable(stop_fn):
                # The stop/close call may block, so keep it off the event loop.
                await run_in_threadpool(stop_fn)
        except Exception as e:
            app.state.logger.warning("Error during pipeline shutdown: %s", e, exc_info=True)

        app.state.logger.info("Lifespan shutdown complete")
app = FastAPI(lifespan=lifespan)
logger = logging.getLogger("DiffusersServer.Pipelines")

# Configure logging before the pipeline is built and the INFO line below is
# emitted. The lifespan hook also calls basicConfig, but only at server startup,
# which is after this import-time code. Without this, the root logger is still
# at its default WARNING level and the message is dropped. basicConfig is a
# no-op if something already configured the root logger.
logging.basicConfig(level=logging.INFO)

# Build and start the pipeline once, at import time, and share it across requests.
initializer = ModelPipelineInitializer(
    model=server_config.model,
    type_models=server_config.type_models,
)
model_pipeline = initializer.initialize_pipeline()
model_pipeline.start()
request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
pipeline_lock = threading.Lock()
logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")

# Expose the shared objects to handlers through app.state.
app.state.MODEL_INITIALIZER = initializer
app.state.MODEL_PIPELINE = model_pipeline
app.state.REQUEST_PIPE = request_pipe
app.state.PIPELINE_LOCK = pipeline_lock
class JSONBodyQueryAPI(BaseModel):
    """Request body for ``POST /api/diffusers/inference``."""

    # Accepted in the payload; the inference handler does not read it.
    model: Optional[str] = None
    # Text prompt to generate from (required).
    prompt: str
    # Optional negative prompt; the handler substitutes "" when absent.
    negative_prompt: Optional[str] = None
    # Number of inference steps forwarded to the pipeline.
    num_inference_steps: int = 28
    # Number of images to generate for the prompt.
    num_images_per_prompt: int = 1
@app.middleware("http")
async def count_requests_middleware(request: Request, call_next):
    """Increment the shared request counter, then pass the request down the stack."""
    counter_lock = app.state.metrics_lock
    async with counter_lock:
        app.state.total_requests = app.state.total_requests + 1
    return await call_next(request)
@app.get("/")
async def root():
    """Return a fixed welcome message."""
    welcome = {"message": "Welcome to the Diffusers Server"}
    return welcome
@app.post("/api/diffusers/inference")
async def api(json: JSONBodyQueryAPI):
    """Run one text-to-image inference and return URLs of the saved images.

    Args:
        json: Parsed request body. ``model`` is accepted but not read here.

    Returns:
        ``{"response": [...]}`` holding one ``utils_app.save_image`` result per
        generated image.

    Raises:
        HTTPException: 400 for an empty prompt or a non-positive step/image
            count; 500 if the pipeline is not initialized or inference/saving
            fails.
    """
    prompt = json.prompt
    negative_prompt = json.negative_prompt or ""
    num_steps = json.num_inference_steps
    num_images_per_prompt = json.num_images_per_prompt

    wrapper = app.state.MODEL_PIPELINE
    initializer = app.state.MODEL_INITIALIZER
    utils_app = app.state.utils_app

    if not wrapper or not wrapper.pipeline:
        raise HTTPException(500, "Model not initialized correctly")
    if not prompt.strip():
        raise HTTPException(400, "No prompt provided")
    # Reject meaningless counts up front as client errors rather than handing
    # them to the pipeline.
    if num_steps < 1:
        raise HTTPException(400, "num_inference_steps must be >= 1")
    if num_images_per_prompt < 1:
        raise HTTPException(400, "num_images_per_prompt must be >= 1")

    req_pipe = app.state.REQUEST_PIPE

    def make_generator():
        # A fresh random seed is drawn for every request.
        g = torch.Generator(device=initializer.device)
        return g.manual_seed(random.randint(0, 10_000_000))

    def infer():
        # Blocking pipeline call; executed in a worker thread.
        gen = make_generator()
        return req_pipe.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=gen,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            device=initializer.device,
            output_type="pil",
        )

    def release_memory():
        # Blocking: synchronize() waits for queued GPU work. Run it off the
        # event loop so other requests keep being served meanwhile.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.ipc_collect()
        gc.collect()

    try:
        async with app.state.metrics_lock:
            app.state.active_inferences += 1
        try:
            output = await run_in_threadpool(infer)
        finally:
            # Exactly one decrement per increment: on success, on an inference
            # error, and on cancellation. This does not depend on whether saving
            # later fails.
            async with app.state.metrics_lock:
                app.state.active_inferences = max(0, app.state.active_inferences - 1)
        urls = [utils_app.save_image(img) for img in output.images]
        return {"response": urls}
    except Exception as e:
        logger.exception("Error during inference: %s", e)
        raise HTTPException(500, f"Error in processing: {e}") from e
    finally:
        # Best-effort cleanup: a failure here must not replace the real outcome.
        try:
            await run_in_threadpool(release_memory)
        except Exception:
            logger.warning("Post-inference memory cleanup failed", exc_info=True)
@app.get("/images/{filename}")
async def serve_image(filename: str):
    """Serve a generated image from ``utils_app.image_dir`` as ``image/png``.

    Raises:
        HTTPException: 404 unless ``filename`` resolves to a regular file
            inside the image directory.
    """
    utils_app = app.state.utils_app
    base_dir = os.path.realpath(utils_app.image_dir)
    # ``filename`` comes straight from the URL. Resolve it and require the result
    # to stay inside the image directory. This blocks escapes such as "..\\x"
    # (backslash is a separator on Windows), drive-qualified paths, or symlinks.
    file_path = os.path.realpath(os.path.join(base_dir, filename))
    try:
        inside = os.path.commonpath([base_dir, file_path]) == base_dir
    except ValueError:
        # Paths on different drives (Windows) cannot be inside base_dir.
        inside = False
    if not inside or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(file_path, media_type="image/png")
@app.get("/api/status")
async def get_status():
    """Report the configured model and, when CUDA is available, GPU memory usage.

    Memory figures are in GiB (bytes / 1024**3), rounded to two decimals. The
    device name is always read from CUDA device 0.
    """
    bytes_per_gb = 1024**3
    memory = {}
    if torch.cuda.is_available():
        memory = {
            "memory_allocated_gb": round(torch.cuda.memory_allocated() / bytes_per_gb, 2),
            "memory_reserved_gb": round(torch.cuda.memory_reserved() / bytes_per_gb, 2),
            "device": torch.cuda.get_device_name(0),
        }
    status = {
        "current_model": server_config.model,
        "type_models": server_config.type_models,
        "memory": memory,
    }
    return status
# CORS: any origin may call the API. Credentials are disabled because the CORS
# spec forbids a wildcard origin together with credentials. With
# allow_credentials=True, Starlette instead echoes back the caller's Origin,
# letting any website make credentialed requests. None of the endpoints in this
# module read cookies or auth headers, so nothing depends on credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when the file is run as a script.
    import uvicorn

    bind_host, bind_port = server_config.host, server_config.port
    uvicorn.run(app, host=bind_host, port=bind_port)