Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

Add RequestScopedPipeline for safe concurrent inference, tokenizer lock and non-mutating retrieve_timesteps (#12328)

* Basic implementation of request scheduling
* Basic editing in SD and Flux Pipelines
* Small Fix
* Fix
* Update for more pipelines
* Add examples/server-async
* Add examples/server-async
* Updated RequestScopedPipeline to handle a single tokenizer lock to avoid race conditions
* Fix
* Fix _TokenizerLockWrapper
* Fix _TokenizerLockWrapper
* Delete _TokenizerLockWrapper
* Fix tokenizer
* Update examples/server-async
* Fix server-async
* Optimizations in examples/server-async
* We keep the implementation simple in examples/server-async
* Update examples/server-async/README.md
* Update examples/server-async/README.md for changes to tokenizer locks and backward-compatible retrieve_timesteps
* The changes to the diffusers core have been undone and all logic is being moved to examples/server-async
* Update examples/server-async/utils/*
* Fix BaseAsyncScheduler
* Rollback in the core of the diffusers
* Update examples/server-async/README.md
* Complete rollback of diffusers core files
* Simple implementation of an asynchronous server compatible with SD3-3.5 and Flux Pipelines
* Update examples/server-async/README.md
* Fixed import errors in 'examples/server-async/serverasync.py'
* Flux Pipeline Discard
* Update examples/server-async/README.md
* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

examples/server-async/Pipelines.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import logging
import os
from dataclasses import dataclass, field
from typing import List

import torch
from pydantic import BaseModel

from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline


logger = logging.getLogger(__name__)


class TextToImageInput(BaseModel):
    model: str
    prompt: str
    size: str | None = None
    n: int | None = None


@dataclass
class PresetModels:
    SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
    SD3_5: List[str] = field(
        default_factory=lambda: [
            "stabilityai/stable-diffusion-3.5-large",
            "stabilityai/stable-diffusion-3.5-large-turbo",
            "stabilityai/stable-diffusion-3.5-medium",
        ]
    )


class TextToImagePipelineSD3:
    """Loads a StableDiffusion3Pipeline once and keeps it in memory."""

    def __init__(self, model_path: str | None = None):
        self.model_path = model_path or os.getenv("MODEL_PATH")
        self.pipeline: StableDiffusion3Pipeline | None = None
        self.device: str | None = None

    def start(self):
        if torch.cuda.is_available():
            model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
            logger.info("Loading CUDA")
            self.device = "cuda"
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
            ).to(device=self.device)
        elif torch.backends.mps.is_available():
            model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
            logger.info("Loading MPS for Mac M Series")
            self.device = "mps"
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            ).to(device=self.device)
        else:
            raise Exception("No CUDA or MPS device available")


class ModelPipelineInitializer:
    def __init__(self, model: str = "", type_models: str = "t2im"):
        self.model = model
        self.type_models = type_models
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "mps"
        self.model_type = None

    def initialize_pipeline(self):
        if not self.model:
            raise ValueError("Model name not provided")

        # Check whether the model exists in PresetModels
        preset_models = PresetModels()

        # Determine which model type we're dealing with
        if self.model in preset_models.SD3:
            self.model_type = "SD3"
        elif self.model in preset_models.SD3_5:
            self.model_type = "SD3_5"

        # Create the appropriate pipeline based on model type and type_models
        if self.type_models == "t2im":
            if self.model_type in ["SD3", "SD3_5"]:
                self.pipeline = TextToImagePipelineSD3(self.model)
            else:
                raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
        else:
            raise ValueError(f"Unsupported type_models: {self.type_models}")

        return self.pipeline

examples/server-async/README.md (new file, 171 lines)
@@ -0,0 +1,171 @@
# Asynchronous server and parallel execution of models

> Example/demo server that keeps a single model in memory while safely running parallel inference requests, by creating per-request lightweight views and cloning only the small stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
> We recommend running 10 to 50 inferences in parallel for good throughput; per-request latency typically ranges from roughly 25-30 seconds up to 1 minute 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise keep it to one or two parallel inferences to avoid decoding or saving errors caused by memory shortages.)

## ⚠️ IMPORTANT

* The example demonstrates how to run pipelines such as `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on the GPU.

## Necessary components

All the components needed to create the inference server are in the current directory:

```
server-async/
├── utils/
│   ├── __init__.py
│   ├── scheduler.py              # BaseAsyncScheduler wrapper and async_retrieve_timesteps for safe inference
│   ├── requestscopedpipeline.py  # RequestScopedPipeline for inference with a single in-memory model
│   └── utils.py                  # Image/video saving utilities and service configuration
├── Pipelines.py                  # Pipeline loader classes (SD3)
├── serverasync.py                # FastAPI app with lifespan management and async inference endpoints
├── test.py                       # Client test script for inference requests
├── requirements.txt              # Dependencies
└── README.md                     # This documentation
```

## What `diffusers-async` adds / Why we needed it

Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per request.

`diffusers-async` / this example addresses that by:

* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request, so the heavy weights (UNet/transformer, VAE, text encoders) remain shared and *are not duplicated*.
* **Per-request mutable state**: small stateful objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with a fallback to a safe `deepcopy` or other heuristics.
* **Tokenizer concurrency safety**: `RequestScopedPipeline` manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This makes Rust tokenizers safe to use under concurrency, so race-condition errors like `Already borrowed` no longer occur.
* **`async_retrieve_timesteps(..., return_scheduler=True)`**: a fully backward-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For callers that don't pass `return_scheduler=True`, the behavior is identical to the original API (see the usage sketch below).
* **Robust attribute handling**: the wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone, while a configurable tensor-size threshold prevents large tensors from being duplicated.
* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with a transparent `__getattr__`/`__setattr__` proxy and debugging helpers (`__repr__`, `__str__`).
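
A minimal usage sketch of that helper (assuming `pipe` is an already-loaded SD3 pipeline on CUDA; step counts are illustrative):

```python
from utils.scheduler import BaseAsyncScheduler, async_retrieve_timesteps

# Default path: same contract as the original helper, and it configures
# (i.e. mutates) the scheduler instance it is given.
timesteps, num_steps = async_retrieve_timesteps(
    pipe.scheduler, num_inference_steps=28, device="cuda"
)

# Non-mutating path: set_timesteps runs on a per-request clone
# (clone_for_request if available, deepcopy otherwise), and the configured
# clone is returned alongside the timesteps; pipe.scheduler is untouched.
timesteps, num_steps, local_scheduler = async_retrieve_timesteps(
    BaseAsyncScheduler(pipe.scheduler),
    num_inference_steps=28,
    device="cuda",
    return_scheduler=True,
)
```
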
## How the server works (high-level flow)

1. A **single model instance** is loaded into memory (GPU/MPS) when the server starts.
2. On each HTTP inference request, the server calls `RequestScopedPipeline.generate(...)`, which:

   * automatically wraps the base scheduler in `BaseAsyncScheduler` (if it isn't already wrapped),
   * obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
   * does `local_pipe = copy.copy(base_pipe)` (a shallow copy),
   * sets `local_pipe.scheduler = local_scheduler` (if possible),
   * clones only small mutable attributes (callbacks, RNG, small latents) with auto-detection,
   * wraps tokenizers with thread-safe locks to prevent race conditions,
   * optionally enters a `model_cpu_offload_context()` for memory offload hooks,
   * calls the pipeline on the local view (`local_pipe(...)`).
3. **Result**: inference completes, images are moved to CPU and saved (if requested), and internal buffers are freed (GC + `torch.cuda.empty_cache()`).
4. Multiple requests can run in parallel while sharing the heavy weights and isolating mutable state; see the sketch right after this list.
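
Reduced to its essentials, the per-request call pattern looks roughly like this (a sketch assuming `base_pipeline` is an already-loaded diffusers pipeline and CUDA is available; the real endpoint lives in `serverasync.py`):

```python
import torch
from fastapi.concurrency import run_in_threadpool

from utils import RequestScopedPipeline

# The heavy pipeline is created once and stays shared on the GPU.
request_pipe = RequestScopedPipeline(base_pipeline)


async def handle_request(prompt: str) -> list:
    generator = torch.Generator(device="cuda").manual_seed(1234)

    def infer():
        # Each call builds a shallow per-request view: a local scheduler,
        # cloned small mutable attrs, and lock-wrapped tokenizers.
        # The model weights themselves are never copied.
        return request_pipe.generate(
            prompt=prompt,
            generator=generator,
            num_inference_steps=28,
            device="cuda",
            output_type="pil",
        )

    # Run the blocking diffusion call in a worker thread so the event loop stays free.
    result = await run_in_threadpool(infer)
    return result.images
```
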
## How to set up and run the server

### 1) Install dependencies

Recommended: create a virtualenv / conda environment.

```bash
pip install diffusers
pip install -r requirements.txt
```

### 2) Start the server

Use the `serverasync.py` file, which already has everything you need:

```bash
python serverasync.py
```

The server starts on `http://localhost:8500` by default with the following features:

- FastAPI application with async lifespan management
- Automatic model loading and pipeline initialization
- Request counting and active-inference tracking
- Memory cleanup after each inference
- CORS middleware for cross-origin requests

### 3) Test the server

Use the included test script:

```bash
python test.py
```

Or send a manual request: `POST /api/diffusers/inference` with JSON body:

```json
{
    "prompt": "A futuristic cityscape, vibrant colors",
    "num_inference_steps": 30,
    "num_images_per_prompt": 1
}
```

Response example:

```json
{
    "response": ["http://localhost:8500/images/img123.png"]
}
```

### 4) Server endpoints

- `GET /` - Welcome message
- `POST /api/diffusers/inference` - Main inference endpoint
- `GET /images/{filename}` - Serve generated images
- `GET /api/status` - Server status and memory info

## Advanced Configuration

### RequestScopedPipeline Parameters

```python
RequestScopedPipeline(
    pipeline,                          # Base pipeline to wrap
    mutable_attrs=None,                # Custom list of attributes to clone
    auto_detect_mutables=True,         # Enable automatic detection of mutable attributes
    tensor_numel_threshold=1_000_000,  # Tensor size threshold for cloning
    tokenizer_lock=None,               # Custom threading lock for tokenizers
    wrap_scheduler=True,               # Auto-wrap the scheduler in BaseAsyncScheduler
)
```
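
For stricter control you can override these defaults. A small illustrative sketch (`pipeline` is assumed to be an already-loaded diffusers pipeline; the attribute names and threshold below are example values only):

```python
import threading

from utils import RequestScopedPipeline

shared_lock = threading.Lock()  # one lock reused across several wrapped pipelines

request_pipe = RequestScopedPipeline(
    pipeline,                                           # already-loaded pipeline (assumed)
    mutable_attrs=["_progress_bar_config", "latents"],  # clone only these attributes
    auto_detect_mutables=False,                         # skip the auto-detection pass
    tensor_numel_threshold=500_000,                     # clone tensors only up to this size
    tokenizer_lock=shared_lock,
)
```
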
### BaseAsyncScheduler Features

* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
* `clone_for_request()` method for safe per-request scheduler cloning
* Enhanced debugging with `__repr__` and `__str__` methods
* Full compatibility with existing scheduler APIs
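
A short, self-contained sketch of that behavior (the concrete scheduler class here is only an example; any diffusers scheduler can be wrapped the same way):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

from utils.scheduler import BaseAsyncScheduler

wrapped = BaseAsyncScheduler(FlowMatchEulerDiscreteScheduler())

# Attribute access is proxied to the underlying scheduler.
print(wrapped.config.num_train_timesteps)

# clone_for_request deep-copies the scheduler and calls set_timesteps on the copy,
# so the shared instance is never mutated.
local = wrapped.clone_for_request(num_inference_steps=28, device="cpu")
print(len(local.timesteps))  # 28 timesteps configured on the per-request clone
print(repr(local))           # BaseAsyncScheduler(...) via the custom __repr__
```
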
### Server Configuration

The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:

```python
@dataclass
class ServerConfigModels:
    model: str = 'stabilityai/stable-diffusion-3.5-medium'
    type_models: str = 't2im'
    host: str = '0.0.0.0'
    port: int = 8500
```

## Troubleshooting (quick)

* `Already borrowed` - previously a Rust tokenizer concurrency error.
  ✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so these race conditions no longer happen.

* `can't set attribute 'components'` - the pipeline exposes a read-only `components` property.
  ✅ `RequestScopedPipeline` detects read-only properties and skips setting them automatically.

* Scheduler issues:
  * If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, the wrapper logs the failure and falls back; prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
  ✅ Note: `async_retrieve_timesteps` is fully backward compatible; if you don't pass `return_scheduler=True`, the behavior is unchanged.

* Memory issues with large tensors:
  ✅ The configurable `tensor_numel_threshold` prevents cloning of large tensors while small mutable ones are still cloned.

* Automatic tokenizer detection:
  ✅ Tokenizer components are identified by checking for tokenizer methods, class names, and attributes, and are then wrapped with thread-safe locks.

examples/server-async/requirements.txt (new file, 10 lines)
@@ -0,0 +1,10 @@
torch
torchvision
transformers
sentencepiece
fastapi
uvicorn
ftfy
accelerate
xformers
protobuf

examples/server-async/serverasync.py (new file, 230 lines)
@@ -0,0 +1,230 @@
import asyncio
import gc
import logging
import os
import random
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from Pipelines import ModelPipelineInitializer
from pydantic import BaseModel

from utils import RequestScopedPipeline, Utils


@dataclass
class ServerConfigModels:
    model: str = "stabilityai/stable-diffusion-3.5-medium"
    type_models: str = "t2im"
    constructor_pipeline: Optional[Type] = None
    custom_pipeline: Optional[Type] = None
    components: Optional[Dict[str, Any]] = None
    torch_dtype: Optional[torch.dtype] = None
    host: str = "0.0.0.0"
    port: int = 8500


server_config = ServerConfigModels()


@asynccontextmanager
async def lifespan(app: FastAPI):
    logging.basicConfig(level=logging.INFO)
    app.state.logger = logging.getLogger("diffusers-server")
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

    app.state.total_requests = 0
    app.state.active_inferences = 0
    app.state.metrics_lock = asyncio.Lock()
    app.state.metrics_task = None

    app.state.utils_app = Utils(
        host=server_config.host,
        port=server_config.port,
    )

    async def metrics_loop():
        # Periodically log request/inference counters; the lock is held only while reading them.
        try:
            while True:
                async with app.state.metrics_lock:
                    total = app.state.total_requests
                    active = app.state.active_inferences
                app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
                await asyncio.sleep(5)
        except asyncio.CancelledError:
            app.state.logger.info("Metrics loop cancelled")
            raise

    app.state.metrics_task = asyncio.create_task(metrics_loop())

    try:
        yield
    finally:
        # Shut down the metrics task, then let the pipeline clean up if it exposes stop()/close().
        task = app.state.metrics_task
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        try:
            stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
            if callable(stop_fn):
                await run_in_threadpool(stop_fn)
        except Exception as e:
            app.state.logger.warning(f"Error during pipeline shutdown: {e}")

        app.state.logger.info("Lifespan shutdown complete")


app = FastAPI(lifespan=lifespan)

logger = logging.getLogger("DiffusersServer.Pipelines")


initializer = ModelPipelineInitializer(
    model=server_config.model,
    type_models=server_config.type_models,
)
model_pipeline = initializer.initialize_pipeline()
model_pipeline.start()

request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
pipeline_lock = threading.Lock()

logger.info(f"Pipeline initialized and ready to receive requests (model={server_config.model})")

app.state.MODEL_INITIALIZER = initializer
app.state.MODEL_PIPELINE = model_pipeline
app.state.REQUEST_PIPE = request_pipe
app.state.PIPELINE_LOCK = pipeline_lock


class JSONBodyQueryAPI(BaseModel):
    model: str | None = None
    prompt: str
    negative_prompt: str | None = None
    num_inference_steps: int = 28
    num_images_per_prompt: int = 1


@app.middleware("http")
async def count_requests_middleware(request: Request, call_next):
    async with app.state.metrics_lock:
        app.state.total_requests += 1
    response = await call_next(request)
    return response


@app.get("/")
async def root():
    return {"message": "Welcome to the Diffusers Server"}


@app.post("/api/diffusers/inference")
async def api(json: JSONBodyQueryAPI):
    prompt = json.prompt
    negative_prompt = json.negative_prompt or ""
    num_steps = json.num_inference_steps
    num_images_per_prompt = json.num_images_per_prompt

    wrapper = app.state.MODEL_PIPELINE
    initializer = app.state.MODEL_INITIALIZER

    utils_app = app.state.utils_app

    if not wrapper or not wrapper.pipeline:
        raise HTTPException(500, "Model not initialized correctly")
    if not prompt.strip():
        raise HTTPException(400, "No prompt provided")

    def make_generator():
        g = torch.Generator(device=initializer.device)
        return g.manual_seed(random.randint(0, 10_000_000))

    req_pipe = app.state.REQUEST_PIPE

    def infer():
        # Runs in a worker thread; RequestScopedPipeline isolates the mutable state per request.
        gen = make_generator()
        return req_pipe.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=gen,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            device=initializer.device,
            output_type="pil",
        )

    try:
        async with app.state.metrics_lock:
            app.state.active_inferences += 1

        output = await run_in_threadpool(infer)

        async with app.state.metrics_lock:
            app.state.active_inferences = max(0, app.state.active_inferences - 1)

        urls = [utils_app.save_image(img) for img in output.images]
        return {"response": urls}

    except Exception as e:
        async with app.state.metrics_lock:
            app.state.active_inferences = max(0, app.state.active_inferences - 1)
        logger.error(f"Error during inference: {e}")
        raise HTTPException(500, f"Error in processing: {e}")

    finally:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.ipc_collect()
        gc.collect()


@app.get("/images/{filename}")
async def serve_image(filename: str):
    utils_app = app.state.utils_app
    file_path = os.path.join(utils_app.image_dir, filename)
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(file_path, media_type="image/png")


@app.get("/api/status")
async def get_status():
    memory_info = {}
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
        memory_info = {
            "memory_allocated_gb": round(memory_allocated, 2),
            "memory_reserved_gb": round(memory_reserved, 2),
            "device": torch.cuda.get_device_name(0),
        }

    return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}


app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=server_config.host, port=server_config.port)

examples/server-async/test.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import os
import time
import urllib.parse

import requests


SERVER_URL = "http://localhost:8500/api/diffusers/inference"
BASE_URL = "http://localhost:8500"
DOWNLOAD_FOLDER = "generated_images"
WAIT_BEFORE_DOWNLOAD = 2  # seconds

os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)


def save_from_url(url: str) -> str:
    """Download the given URL (relative or absolute) and save it locally."""
    if url.startswith("/"):
        direct = BASE_URL.rstrip("/") + url
    else:
        direct = url
    resp = requests.get(direct, timeout=60)
    resp.raise_for_status()
    filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
    path = os.path.join(DOWNLOAD_FOLDER, filename)
    with open(path, "wb") as f:
        f.write(resp.content)
    return path


def main():
    payload = {
        "prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
        "num_inference_steps": 30,
        "num_images_per_prompt": 1,
    }

    print("Sending request...")
    try:
        r = requests.post(SERVER_URL, json=payload, timeout=480)
        r.raise_for_status()
    except Exception as e:
        print(f"Request failed: {e}")
        return

    body = r.json().get("response", [])
    # Normalize to a list
    urls = body if isinstance(body, list) else [body] if body else []
    if not urls:
        print("No URLs found in the response. Check the server output.")
        return

    print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
    time.sleep(WAIT_BEFORE_DOWNLOAD)

    for u in urls:
        try:
            path = save_from_url(u)
            print(f"Image saved to: {path}")
        except Exception as e:
            print(f"Error downloading {u}: {e}")


if __name__ == "__main__":
    main()

examples/server-async/utils/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .requestscopedpipeline import RequestScopedPipeline
from .utils import Utils

examples/server-async/utils/requestscopedpipeline.py (new file, 296 lines)
@@ -0,0 +1,296 @@
import copy
import threading
import types
from typing import Any, Iterable, List, Optional

import torch

from diffusers.utils import logging

from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps


logger = logging.get_logger(__name__)


def safe_tokenize(tokenizer, *args, lock, **kwargs):
    with lock:
        return tokenizer(*args, **kwargs)


class RequestScopedPipeline:
    DEFAULT_MUTABLE_ATTRS = [
        "_all_hooks",
        "_offload_device",
        "_progress_bar_config",
        "_progress_bar",
        "_rng_state",
        "_last_seed",
        "latents",
    ]

    def __init__(
        self,
        pipeline: Any,
        mutable_attrs: Optional[Iterable[str]] = None,
        auto_detect_mutables: bool = True,
        tensor_numel_threshold: int = 1_000_000,
        tokenizer_lock: Optional[threading.Lock] = None,
        wrap_scheduler: bool = True,
    ):
        self._base = pipeline
        self.unet = getattr(pipeline, "unet", None)
        self.vae = getattr(pipeline, "vae", None)
        self.text_encoder = getattr(pipeline, "text_encoder", None)
        self.components = getattr(pipeline, "components", None)

        if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
            if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)

        self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
        self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()

        self._auto_detect_mutables = bool(auto_detect_mutables)
        self._tensor_numel_threshold = int(tensor_numel_threshold)

        self._auto_detected_attrs: List[str] = []

    def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
        base_sched = getattr(self._base, "scheduler", None)
        if base_sched is None:
            return None

        if not isinstance(base_sched, BaseAsyncScheduler):
            wrapped_scheduler = BaseAsyncScheduler(base_sched)
        else:
            wrapped_scheduler = base_sched

        try:
            return wrapped_scheduler.clone_for_request(
                num_inference_steps=num_inference_steps, device=device, **clone_kwargs
            )
        except Exception as e:
            logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
            try:
                return copy.deepcopy(wrapped_scheduler)
            except Exception as e:
                logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
                return wrapped_scheduler

    def _autodetect_mutables(self, max_attrs: int = 40):
        if not self._auto_detect_mutables:
            return []

        if self._auto_detected_attrs:
            return self._auto_detected_attrs

        candidates: List[str] = []
        seen = set()
        for name in dir(self._base):
            if name.startswith("__"):
                continue
            if name in self._mutable_attrs:
                continue
            if name in ("to", "save_pretrained", "from_pretrained"):
                continue
            try:
                val = getattr(self._base, name)
            except Exception:
                continue

            # skip callables and modules
            if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
                continue

            # containers -> candidate
            if isinstance(val, (dict, list, set, tuple, bytearray)):
                candidates.append(name)
                seen.add(name)
            else:
                # try Tensor detection
                try:
                    if isinstance(val, torch.Tensor):
                        if val.numel() <= self._tensor_numel_threshold:
                            candidates.append(name)
                            seen.add(name)
                        else:
                            logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
                except Exception:
                    continue

            if len(candidates) >= max_attrs:
                break

        self._auto_detected_attrs = candidates
        logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
        return self._auto_detected_attrs

    def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
        try:
            cls = type(base_obj)
            descriptor = getattr(cls, attr_name, None)
            if isinstance(descriptor, property):
                return descriptor.fset is None
            if hasattr(descriptor, "__set__") is False and descriptor is not None:
                return False
        except Exception:
            pass
        return False

    def _clone_mutable_attrs(self, base, local):
        attrs_to_clone = list(self._mutable_attrs)
        attrs_to_clone.extend(self._autodetect_mutables())

        EXCLUDE_ATTRS = {
            "components",
        }

        for attr in attrs_to_clone:
            if attr in EXCLUDE_ATTRS:
                logger.debug(f"Skipping excluded attr '{attr}'")
                continue
            if not hasattr(base, attr):
                continue
            if self._is_readonly_property(base, attr):
                logger.debug(f"Skipping read-only property '{attr}'")
                continue

            try:
                val = getattr(base, attr)
            except Exception as e:
                logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
                continue

            try:
                if isinstance(val, dict):
                    setattr(local, attr, dict(val))
                elif isinstance(val, (list, tuple, set)):
                    setattr(local, attr, list(val))
                elif isinstance(val, bytearray):
                    setattr(local, attr, bytearray(val))
                else:
                    # small tensors or atomic values
                    if isinstance(val, torch.Tensor):
                        if val.numel() <= self._tensor_numel_threshold:
                            setattr(local, attr, val.clone())
                        else:
                            # don't clone big tensors, keep the reference
                            setattr(local, attr, val)
                    else:
                        try:
                            setattr(local, attr, copy.copy(val))
                        except Exception:
                            setattr(local, attr, val)
            except (AttributeError, TypeError) as e:
                logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
                continue
            except Exception as e:
                logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
                continue

    def _is_tokenizer_component(self, component) -> bool:
        if component is None:
            return False

        tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
        has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)

        class_name = component.__class__.__name__.lower()
        has_tokenizer_in_name = "tokenizer" in class_name

        tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
        has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)

        return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)

    def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
        local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)

        try:
            local_pipe = copy.copy(self._base)
        except Exception as e:
            logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
            local_pipe = copy.deepcopy(self._base)

        if local_scheduler is not None:
            try:
                timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
                    local_scheduler.scheduler,
                    num_inference_steps=num_inference_steps,
                    device=device,
                    return_scheduler=True,
                    **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
                )

                final_scheduler = BaseAsyncScheduler(configured_scheduler)
                setattr(local_pipe, "scheduler", final_scheduler)
            except Exception:
                logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")

        self._clone_mutable_attrs(self._base, local_pipe)

        # wrap tokenizers on the local pipe with the lock wrapper
        tokenizer_wrappers = {}  # name -> original_tokenizer
        try:
            # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
            for name in dir(local_pipe):
                if "tokenizer" in name and not name.startswith("_"):
                    tok = getattr(local_pipe, name, None)
                    if tok is not None and self._is_tokenizer_component(tok):
                        tokenizer_wrappers[name] = tok
                        setattr(
                            local_pipe,
                            name,
                            lambda *args, tok=tok, **kwargs: safe_tokenize(
                                tok, *args, lock=self._tokenizer_lock, **kwargs
                            ),
                        )

            # b) wrap tokenizers in the components dict
            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
                for key, val in local_pipe.components.items():
                    if val is None:
                        continue

                    if self._is_tokenizer_component(val):
                        tokenizer_wrappers[f"components[{key}]"] = val
                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
                        )

        except Exception as e:
            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")

        result = None
        cm = getattr(local_pipe, "model_cpu_offload_context", None)
        try:
            if callable(cm):
                try:
                    with cm():
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                except TypeError:
                    # cm might be a context manager instance rather than a callable
                    try:
                        with cm:
                            result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                    except Exception as e:
                        logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
            else:
                # no offload context available; call the pipeline directly
                result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)

            return result

        finally:
            try:
                for name, tok in tokenizer_wrappers.items():
                    if name.startswith("components["):
                        key = name[len("components[") : -1]
                        local_pipe.components[key] = tok
                    else:
                        setattr(local_pipe, name, tok)
            except Exception as e:
                logger.debug(f"Error restoring wrapped tokenizers: {e}")

examples/server-async/utils/scheduler.py (new file, 141 lines)
@@ -0,0 +1,141 @@
import copy
import inspect
from typing import Any, List, Optional, Union

import torch


class BaseAsyncScheduler:
    def __init__(self, scheduler: Any):
        self.scheduler = scheduler

    def __getattr__(self, name: str):
        if hasattr(self.scheduler, name):
            return getattr(self.scheduler, name)
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value):
        if name == "scheduler":
            super().__setattr__(name, value)
        else:
            if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
                setattr(self.scheduler, name, value)
            else:
                super().__setattr__(name, value)

    def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
        local = copy.deepcopy(self.scheduler)
        local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
        cloned = self.__class__(local)
        return cloned

    def __repr__(self):
        return f"BaseAsyncScheduler({repr(self.scheduler)})"

    def __str__(self):
        return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"


def async_retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
    Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Backwards compatible: by default the function behaves exactly as before and returns
    `(timesteps_tensor, num_inference_steps)`.

    If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
    scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
    or a deepcopy fallback, call `set_timesteps` on that clone, and return
    `(timesteps_tensor, num_inference_steps, scheduler_in_use)`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps`
            is passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is
            passed, `num_inference_steps` and `timesteps` must be `None`.

    Optional kwargs:
        return_scheduler (`bool`, defaults to `False`):
            If `True`, return `(timesteps, num_inference_steps, scheduler_in_use)` where `scheduler_in_use`
            is a scheduler instance that already has timesteps set. This mode prefers
            `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.

    Returns:
        `(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
        `(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
    """
    # pop our optional control kwarg (keeps compatibility)
    return_scheduler = bool(kwargs.pop("return_scheduler", False))

    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # choose the scheduler to call set_timesteps on
    scheduler_in_use = scheduler
    if return_scheduler:
        # Do not mutate the provided scheduler: prefer to clone if possible
        if hasattr(scheduler, "clone_for_request"):
            try:
                # clone_for_request may accept num_inference_steps or other kwargs; be permissive
                scheduler_in_use = scheduler.clone_for_request(
                    num_inference_steps=num_inference_steps or 0, device=device
                )
            except Exception:
                scheduler_in_use = copy.deepcopy(scheduler)
        else:
            # fallback deepcopy (schedulers tend to be small, so this is acceptable)
            scheduler_in_use = copy.deepcopy(scheduler)

    # helper to test whether set_timesteps supports a particular kwarg
    def _accepts(param_name: str) -> bool:
        try:
            return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
        except (ValueError, TypeError):
            # if signature introspection fails, be permissive and attempt the call later
            return False

    # now call set_timesteps on the chosen scheduler_in_use (may be the original or a clone)
    if timesteps is not None:
        accepts_timesteps = _accepts("timesteps")
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps_out = scheduler_in_use.timesteps
        num_inference_steps = len(timesteps_out)
    elif sigmas is not None:
        accept_sigmas = _accepts("sigmas")
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps_out = scheduler_in_use.timesteps
        num_inference_steps = len(timesteps_out)
    else:
        # default path
        scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps_out = scheduler_in_use.timesteps

    if return_scheduler:
        return timesteps_out, num_inference_steps, scheduler_in_use
    return timesteps_out, num_inference_steps

examples/server-async/utils/utils.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import gc
import logging
import os
import tempfile
import uuid

import torch


logger = logging.getLogger(__name__)


class Utils:
    def __init__(self, host: str = "0.0.0.0", port: int = 8500):
        self.service_url = f"http://{host}:{port}"
        self.image_dir = os.path.join(tempfile.gettempdir(), "images")
        if not os.path.exists(self.image_dir):
            os.makedirs(self.image_dir)

        self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
        if not os.path.exists(self.video_dir):
            os.makedirs(self.video_dir)

    def save_image(self, image):
        # Move tensor outputs to CPU before converting/saving
        if hasattr(image, "to"):
            try:
                image = image.to("cpu")
            except Exception:
                pass

        if isinstance(image, torch.Tensor):
            from torchvision import transforms

            to_pil = transforms.ToPILImage()
            image = to_pil(image.squeeze(0).clamp(0, 1))

        filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
        image_path = os.path.join(self.image_dir, filename)
        logger.info(f"Saving image to {image_path}")

        image.save(image_path, format="PNG", optimize=True)

        del image
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Build the public URL with string formatting (os.path.join is meant for
        # filesystem paths, not URLs, and would produce backslashes on Windows)
        return f"{self.service_url}/images/{filename}"