Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Add server example (#9918)
* Add server example.
* Minor updates to README.
* Add fixes after local testing.
* Apply suggestions from code review

  Updates to README from code review

  Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* More doc updates.
* Maybe this will work to build the docs correctly?
* Fix style issues.
* Fix toc.
* Minor reformatting.
* Move docs to proper loc.
* Fix missing tick.
* Apply suggestions from code review

  Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* Sync docs changes back to README.
* Very minor update to docs to add space.

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
docs/source/en/_toctree.yml
@@ -55,6 +55,8 @@
   - sections:
     - local: using-diffusers/overview_techniques
       title: Overview
+    - local: using-diffusers/create_a_server
+      title: Create a server
     - local: training/distributed_inference
       title: Distributed inference
     - local: using-diffusers/merge_loras
61
docs/source/en/using-diffusers/create_a_server.md
Normal file
@@ -0,0 +1,61 @@
# Create a server

Diffusers' pipelines can be used as an inference engine for a server. The server handles concurrent, multithreaded requests so that images can be generated for multiple users at the same time.

This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.

Start by navigating to the `examples/server` folder and installing all of the dependencies.

```sh
pip install .
pip install -r requirements.txt
```

Launch the server with the following command.

```sh
python server.py
```

The server is accessible at http://localhost:8000. You can send a request to it with the following curl command.

```sh
curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
```
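
The same request can be made from Python. Below is a minimal client sketch; it assumes the server is already running locally and that the `requests` package is installed.

```py
import requests

# Hypothetical client for the endpoint above; the payload mirrors the curl example.
payload = {"model": "something", "prompt": "a kitten in front of a fireplace"}
response = requests.post("http://localhost:8000/v1/images/generations", json=payload)
response.raise_for_status()

# The endpoint responds with {"data": [{"url": ...}]}, where the URL points to the generated image.
print(response.json()["data"][0]["url"])
```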

If you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.

```sh
uv pip compile requirements.in -o requirements.txt
```

The server is built with [FastAPI](https://fastapi.tiangolo.com/async/). The endpoint for `v1/images/generations` is shown below.

```py
@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
    try:
        loop = asyncio.get_event_loop()
        scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
        pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
        generator = torch.Generator(device=shared_pipeline.device)
        generator.manual_seed(random.randint(0, 10000000))
        output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
        logger.info(f"output: {output}")
        image_url = save_image(output.images[0])
        return {"data": [{"url": image_url}]}
    except Exception as e:
        if isinstance(e, HTTPException):
            raise e
        elif hasattr(e, "message"):
            raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
```

The `generate_image` function is declared with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows the function won't necessarily return a result right away. As soon as the function reaches a point where it needs to await some other [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the main thread goes back to answering other HTTP requests. This is shown in the code below with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword.

```py
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
```

At this point, the execution of the pipeline function is placed onto a [new thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor), and the main thread keeps serving other requests until the `pipeline` call returns a result.
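
To see this mechanism in isolation, here is a small, self-contained sketch (not part of the server) in which a blocking call is offloaded with `run_in_executor` while the event loop stays responsive.

```py
import asyncio
import time


def blocking_work():
    # Stands in for a long, blocking call such as running the diffusion pipeline.
    time.sleep(2)
    return "done"


async def heartbeat():
    # Keeps running while the blocking work executes on a worker thread.
    for _ in range(4):
        print("event loop is still responsive")
        await asyncio.sleep(0.5)


async def main():
    loop = asyncio.get_running_loop()
    # Awaiting the executor future frees the event loop to run heartbeat() concurrently.
    result, _ = await asyncio.gather(loop.run_in_executor(None, blocking_work), heartbeat())
    print(result)


asyncio.run(main())
```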

Another important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. The goal is to avoid loading the underlying model onto the GPU more than once, while still letting each request, which runs on its own thread, have its own generator and scheduler. The scheduler, in particular, is not thread-safe, and sharing it across threads causes errors like `IndexError: index 21 is out of bounds for dimension 0 with size 21`.
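
As a rough illustration (a sketch, not part of the server), [`~DiffusionPipeline.from_pipe`] reuses the already-loaded components by reference, so only the lightweight scheduler is new for each request.

```py
# Sketch: assumes shared_pipeline.start() has already loaded the model.
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
request_pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)

# The heavy modules are shared (no additional GPU memory), but the scheduler is per-request.
assert request_pipeline.transformer is shared_pipeline.pipeline.transformer
assert request_pipeline.scheduler is not shared_pipeline.pipeline.scheduler
```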
61
examples/server/README.md
Normal file
@@ -0,0 +1,61 @@
# Create a server

Diffusers' pipelines can be used as an inference engine for a server. The server handles concurrent, multithreaded requests so that images can be generated for multiple users at the same time.

This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.

Start by navigating to the `examples/server` folder and installing all of the dependencies.

```sh
pip install .
pip install -r requirements.txt
```

Launch the server with the following command.

```sh
python server.py
```

The server is accessible at http://localhost:8000. You can send a request to it with the following curl command.

```sh
curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
```
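
The same request can be made from Python. Below is a minimal client sketch; it assumes the server is already running locally and that the `requests` package is installed.

```py
import requests

# Hypothetical client for the endpoint above; the payload mirrors the curl example.
payload = {"model": "something", "prompt": "a kitten in front of a fireplace"}
response = requests.post("http://localhost:8000/v1/images/generations", json=payload)
response.raise_for_status()

# The endpoint responds with {"data": [{"url": ...}]}, where the URL points to the generated image.
print(response.json()["data"][0]["url"])
```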

If you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.

```sh
uv pip compile requirements.in -o requirements.txt
```

The server is built with [FastAPI](https://fastapi.tiangolo.com/async/). The endpoint for `v1/images/generations` is shown below.

```py
@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
    try:
        loop = asyncio.get_event_loop()
        scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
        pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
        generator = torch.Generator(device=shared_pipeline.device)
        generator.manual_seed(random.randint(0, 10000000))
        output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
        logger.info(f"output: {output}")
        image_url = save_image(output.images[0])
        return {"data": [{"url": image_url}]}
    except Exception as e:
        if isinstance(e, HTTPException):
            raise e
        elif hasattr(e, "message"):
            raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
```

The `generate_image` function is declared with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows the function won't necessarily return a result right away. As soon as the function reaches a point where it needs to await some other [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the main thread goes back to answering other HTTP requests. This is shown in the code below with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword.

```py
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
```

At this point, the execution of the pipeline function is placed onto a [new thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor), and the main thread keeps serving other requests until the `pipeline` call returns a result.
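
To see this mechanism in isolation, here is a small, self-contained sketch (not part of the server) in which a blocking call is offloaded with `run_in_executor` while the event loop stays responsive.

```py
import asyncio
import time


def blocking_work():
    # Stands in for a long, blocking call such as running the diffusion pipeline.
    time.sleep(2)
    return "done"


async def heartbeat():
    # Keeps running while the blocking work executes on a worker thread.
    for _ in range(4):
        print("event loop is still responsive")
        await asyncio.sleep(0.5)


async def main():
    loop = asyncio.get_running_loop()
    # Awaiting the executor future frees the event loop to run heartbeat() concurrently.
    result, _ = await asyncio.gather(loop.run_in_executor(None, blocking_work), heartbeat())
    print(result)


asyncio.run(main())
```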

Another important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. The goal is to avoid loading the underlying model onto the GPU more than once, while still letting each request, which runs on its own thread, have its own generator and scheduler. The scheduler, in particular, is not thread-safe, and sharing it across threads causes errors like `IndexError: index 21 is out of bounds for dimension 0 with size 21`.
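
As a rough illustration (a sketch, not part of the server), [`~DiffusionPipeline.from_pipe`] reuses the already-loaded components by reference, so only the lightweight scheduler is new for each request.

```py
# Sketch: assumes shared_pipeline.start() has already loaded the model.
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
request_pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)

# The heavy modules are shared (no additional GPU memory), but the scheduler is per-request.
assert request_pipeline.transformer is shared_pipeline.pipeline.transformer
assert request_pipeline.scheduler is not shared_pipeline.pipeline.scheduler
```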
9
examples/server/requirements.in
Normal file
@@ -0,0 +1,9 @@
torch~=2.4.0
transformers==4.46.1
sentencepiece
aiohttp
py-consul
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
fastapi
uvicorn
124
examples/server/requirements.txt
Normal file
@@ -0,0 +1,124 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile requirements.in -o requirements.txt
aiohappyeyeballs==2.4.3
    # via aiohttp
aiohttp==3.10.10
    # via -r requirements.in
aiosignal==1.3.1
    # via aiohttp
annotated-types==0.7.0
    # via pydantic
anyio==4.6.2.post1
    # via starlette
attrs==24.2.0
    # via aiohttp
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
click==8.1.7
    # via uvicorn
fastapi==0.115.3
    # via -r requirements.in
filelock==3.16.1
    # via
    #   huggingface-hub
    #   torch
    #   transformers
frozenlist==1.5.0
    # via
    #   aiohttp
    #   aiosignal
fsspec==2024.10.0
    # via
    #   huggingface-hub
    #   torch
h11==0.14.0
    # via uvicorn
huggingface-hub==0.26.1
    # via
    #   tokenizers
    #   transformers
idna==3.10
    # via
    #   anyio
    #   requests
    #   yarl
jinja2==3.1.4
    # via torch
markupsafe==3.0.2
    # via jinja2
mpmath==1.3.0
    # via sympy
multidict==6.1.0
    # via
    #   aiohttp
    #   yarl
networkx==3.4.2
    # via torch
numpy==2.1.2
    # via transformers
packaging==24.1
    # via
    #   huggingface-hub
    #   transformers
prometheus-client==0.21.0
    # via
    #   -r requirements.in
    #   prometheus-fastapi-instrumentator
prometheus-fastapi-instrumentator==7.0.0
    # via -r requirements.in
propcache==0.2.0
    # via yarl
py-consul==1.5.3
    # via -r requirements.in
pydantic==2.9.2
    # via fastapi
pydantic-core==2.23.4
    # via pydantic
pyyaml==6.0.2
    # via
    #   huggingface-hub
    #   transformers
regex==2024.9.11
    # via transformers
requests==2.32.3
    # via
    #   huggingface-hub
    #   py-consul
    #   transformers
safetensors==0.4.5
    # via transformers
sentencepiece==0.2.0
    # via -r requirements.in
sniffio==1.3.1
    # via anyio
starlette==0.41.0
    # via
    #   fastapi
    #   prometheus-fastapi-instrumentator
sympy==1.13.3
    # via torch
tokenizers==0.20.1
    # via transformers
torch==2.4.1
    # via -r requirements.in
tqdm==4.66.5
    # via
    #   huggingface-hub
    #   transformers
transformers==4.46.1
    # via -r requirements.in
typing-extensions==4.12.2
    # via
    #   fastapi
    #   huggingface-hub
    #   pydantic
    #   pydantic-core
    #   torch
urllib3==2.2.3
    # via requests
uvicorn==0.32.0
    # via -r requirements.in
yarl==1.16.0
    # via aiohttp
133
examples/server/server.py
Normal file
@@ -0,0 +1,133 @@
import asyncio
import logging
import os
import random
import tempfile
import traceback
import uuid

import aiohttp
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline


logger = logging.getLogger(__name__)


class TextToImageInput(BaseModel):
    model: str
    prompt: str
    size: str | None = None
    n: int | None = None


class HttpClient:
    session: aiohttp.ClientSession = None

    def start(self):
        self.session = aiohttp.ClientSession()

    async def stop(self):
        await self.session.close()
        self.session = None

    def __call__(self) -> aiohttp.ClientSession:
        assert self.session is not None
        return self.session


class TextToImagePipeline:
    pipeline: StableDiffusion3Pipeline = None
    device: str = None

    def start(self):
        if torch.cuda.is_available():
            model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-large")
            logger.info("Loading CUDA")
            self.device = "cuda"
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            ).to(device=self.device)
        elif torch.backends.mps.is_available():
            model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-medium")
            logger.info("Loading MPS for Mac M Series")
            self.device = "mps"
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            ).to(device=self.device)
        else:
            raise Exception("No CUDA or MPS device available")


app = FastAPI()
service_url = os.getenv("SERVICE_URL", "http://localhost:8000")
image_dir = os.path.join(tempfile.gettempdir(), "images")
if not os.path.exists(image_dir):
    os.makedirs(image_dir)
app.mount("/images", StaticFiles(directory=image_dir), name="images")
http_client = HttpClient()
shared_pipeline = TextToImagePipeline()

# Configure CORS settings
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods, e.g., GET, POST, OPTIONS, etc.
    allow_headers=["*"],  # Allows all headers
)


@app.on_event("startup")
def startup():
    http_client.start()
    shared_pipeline.start()


def save_image(image):
    filename = "draw" + str(uuid.uuid4()).split("-")[0] + ".png"
    image_path = os.path.join(image_dir, filename)
    # write image to disk at image_path
    logger.info(f"Saving image to {image_path}")
    image.save(image_path)
    return os.path.join(service_url, "images", filename)


@app.get("/")
@app.post("/")
@app.options("/")
async def base():
    return "Welcome to Diffusers! Where you can use diffusion models to generate images"


@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
    try:
        loop = asyncio.get_event_loop()
        scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
        pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
        generator = torch.Generator(device=shared_pipeline.device)
        generator.manual_seed(random.randint(0, 10000000))
        output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
        logger.info(f"output: {output}")
        image_url = save_image(output.images[0])
        return {"data": [{"url": image_url}]}
    except Exception as e:
        if isinstance(e, HTTPException):
            raise e
        elif hasattr(e, "message"):
            raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)