
Add server example (#9918)

* Add server example.

* Minor updates to README.

* Add fixes after local testing.

* Apply suggestions from code review

Updates to README from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* More doc updates.

* Maybe this will work to build the docs correctly?

* Fix style issues.

* Fix toc.

* Minor reformatting.

* Move docs to proper loc.

* Fix missing tick.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Sync docs changes back to README.

* Very minor update to docs to add space.

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Grant Sherrick
2024-11-18 11:26:13 -06:00
committed by GitHub
parent 365a938884
commit c3c94fe71b
6 changed files with 390 additions and 0 deletions

docs/source/en/_toctree.yml

@@ -55,6 +55,8 @@
- sections:
- local: using-diffusers/overview_techniques
title: Overview
- local: using-diffusers/create_a_server
title: Create a server
- local: training/distributed_inference
title: Distributed inference
- local: using-diffusers/merge_loras

docs/source/en/using-diffusers/create_a_server.md Normal file

@@ -0,0 +1,61 @@
# Create a server
Diffusers pipelines can be used as an inference engine for a server. The server supports concurrent and multithreaded requests to generate images that may be requested by multiple users at the same time.
This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.
Start by navigating to the `examples/server` folder and installing all of the dependencies.
```sh
pip install .
pip install -r requirements.txt
```
Launch the server with the following command.
```sh
python server.py
```
The server is accessible at http://localhost:8000. You can query the model with the following curl command.
```sh
curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
```
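You can also call the endpoint from Python. The snippet below is a minimal client sketch (not part of the server code), assuming the server from this guide is running locally on port 8000 and that the `requests` package is available (it is pulled in transitively by the pinned requirements). The response shape matches the `{"data": [{"url": ...}]}` payload returned by the endpoint shown later in this guide; the `model` field is required by the request schema but is not used to select a model.
```py
# Minimal client sketch; assumes the server is running on localhost:8000.
import requests

payload = {
    "model": "something",  # required by the request schema, but not used for model selection
    "prompt": "a kitten in front of a fireplace",
}
response = requests.post(
    "http://localhost:8000/v1/images/generations",
    json=payload,
    timeout=600,  # image generation can take a while
)
response.raise_for_status()

# The endpoint responds with {"data": [{"url": "<link to the generated PNG>"}]}
print(response.json()["data"][0]["url"])
```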
If you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.
```sh
uv pip compile requirements.in -o requirements.txt
```
The server is built with [FastAPI](https://fastapi.tiangolo.com/async/). The endpoint for `v1/images/generations` is shown below.
```py
@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
try:
loop = asyncio.get_event_loop()
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
generator = torch.Generator(device="cuda")
generator.manual_seed(random.randint(0, 10000000))
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
logger.info(f"output: {output}")
image_url = save_image(output.images[0])
return {"data": [{"url": image_url}]}
except Exception as e:
if isinstance(e, HTTPException):
raise e
elif hasattr(e, 'message'):
raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
```
The `generate_image` function is defined with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows the function won't necessarily return a result right away. When execution reaches a point where it has to await another [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the event loop switches back to answering other HTTP requests. This point is marked with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword in the code below.
```py
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
```
At this point, the pipeline call is handed off to a [separate thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor) by `run_in_executor`, and the event loop keeps serving other requests until the `pipeline` call returns a result.
Another important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. The goal is to load the underlying model onto the GPU only once, while still giving each request, running on its own thread, its own generator and scheduler. The scheduler, in particular, is not thread-safe; sharing one scheduler across multiple threads causes errors like `IndexError: index 21 is out of bounds for dimension 0 with size 21`.
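To make the sharing explicit, the sketch below (not part of the server code) assumes a `StableDiffusion3Pipeline` has already been loaded as `shared_pipeline.pipeline`. It checks that a per-request pipeline created with `from_pipe` reuses the heavy model components while getting its own scheduler.
```py
# Illustrative sketch, assuming `shared_pipeline.pipeline` already holds a loaded StableDiffusion3Pipeline.
from diffusers import StableDiffusion3Pipeline

base = shared_pipeline.pipeline

# Give the new pipeline its own scheduler so no mutable scheduler state is shared across threads.
scheduler = base.scheduler.from_config(base.scheduler.config)
request_pipeline = StableDiffusion3Pipeline.from_pipe(base, scheduler=scheduler)

# The large model components are reused, so GPU memory is not duplicated...
assert request_pipeline.transformer is base.transformer
assert request_pipeline.vae is base.vae
# ...while the scheduler is a fresh, per-request instance.
assert request_pipeline.scheduler is not base.scheduler
```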

examples/server/README.md Normal file

@@ -0,0 +1,61 @@
# Create a server
Diffusers pipelines can be used as an inference engine for a server. The server supports concurrent and multithreaded requests to generate images that may be requested by multiple users at the same time.
This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.
Start by navigating to the `examples/server` folder and installing all of the dependencies.
```sh
pip install .
pip install -r requirements.txt
```
Launch the server with the following command.
```sh
python server.py
```
The server is accessible at http://localhost:8000. You can query the model with the following curl command.
```sh
curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
```
If you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.
```sh
uv pip compile requirements.in -o requirements.txt
```
The server is built with [FastAPI](https://fastapi.tiangolo.com/async/). The endpoint for `v1/images/generations` is shown below.
```py
@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
try:
loop = asyncio.get_event_loop()
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
generator = torch.Generator(device="cuda")
generator.manual_seed(random.randint(0, 10000000))
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
logger.info(f"output: {output}")
image_url = save_image(output.images[0])
return {"data": [{"url": image_url}]}
except Exception as e:
if isinstance(e, HTTPException):
raise e
elif hasattr(e, 'message'):
raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
```
The `generate_image` function is defined with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows the function won't necessarily return a result right away. When execution reaches a point where it has to await another [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the event loop switches back to answering other HTTP requests. This point is marked with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword in the code below.
```py
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
```
At this point, the pipeline call is handed off to a [separate thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor) by `run_in_executor`, and the event loop keeps serving other requests until the `pipeline` call returns a result.
Another important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. The goal is to load the underlying model onto the GPU only once, while still giving each request, running on its own thread, its own generator and scheduler. The scheduler, in particular, is not thread-safe; sharing one scheduler across multiple threads causes errors like `IndexError: index 21 is out of bounds for dimension 0 with size 21`.

examples/server/requirements.in Normal file

@@ -0,0 +1,9 @@
torch~=2.4.0
transformers==4.46.1
sentencepiece
aiohttp
py-consul
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
fastapi
uvicorn

examples/server/requirements.txt Normal file

@@ -0,0 +1,124 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
aiohappyeyeballs==2.4.3
# via aiohttp
aiohttp==3.10.10
# via -r requirements.in
aiosignal==1.3.1
# via aiohttp
annotated-types==0.7.0
# via pydantic
anyio==4.6.2.post1
# via starlette
attrs==24.2.0
# via aiohttp
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
click==8.1.7
# via uvicorn
fastapi==0.115.3
# via -r requirements.in
filelock==3.16.1
# via
# huggingface-hub
# torch
# transformers
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2024.10.0
# via
# huggingface-hub
# torch
h11==0.14.0
# via uvicorn
huggingface-hub==0.26.1
# via
# tokenizers
# transformers
idna==3.10
# via
# anyio
# requests
# yarl
jinja2==3.1.4
# via torch
markupsafe==3.0.2
# via jinja2
mpmath==1.3.0
# via sympy
multidict==6.1.0
# via
# aiohttp
# yarl
networkx==3.4.2
# via torch
numpy==2.1.2
# via transformers
packaging==24.1
# via
# huggingface-hub
# transformers
prometheus-client==0.21.0
# via
# -r requirements.in
# prometheus-fastapi-instrumentator
prometheus-fastapi-instrumentator==7.0.0
# via -r requirements.in
propcache==0.2.0
# via yarl
py-consul==1.5.3
# via -r requirements.in
pydantic==2.9.2
# via fastapi
pydantic-core==2.23.4
# via pydantic
pyyaml==6.0.2
# via
# huggingface-hub
# transformers
regex==2024.9.11
# via transformers
requests==2.32.3
# via
# huggingface-hub
# py-consul
# transformers
safetensors==0.4.5
# via transformers
sentencepiece==0.2.0
# via -r requirements.in
sniffio==1.3.1
# via anyio
starlette==0.41.0
# via
# fastapi
# prometheus-fastapi-instrumentator
sympy==1.13.3
# via torch
tokenizers==0.20.1
# via transformers
torch==2.4.1
# via -r requirements.in
tqdm==4.66.5
# via
# huggingface-hub
# transformers
transformers==4.46.1
# via -r requirements.in
typing-extensions==4.12.2
# via
# fastapi
# huggingface-hub
# pydantic
# pydantic-core
# torch
urllib3==2.2.3
# via requests
uvicorn==0.32.0
# via -r requirements.in
yarl==1.16.0
# via aiohttp

examples/server/server.py Normal file

@@ -0,0 +1,133 @@
import asyncio
import logging
import os
import random
import tempfile
import traceback
import uuid
import aiohttp
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline
logger = logging.getLogger(__name__)
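# Request body for /v1/images/generations, loosely following the OpenAI images API shape.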
class TextToImageInput(BaseModel):
model: str
prompt: str
size: str | None = None
n: int | None = None
class HttpClient:
session: aiohttp.ClientSession = None
def start(self):
self.session = aiohttp.ClientSession()
async def stop(self):
await self.session.close()
self.session = None
def __call__(self) -> aiohttp.ClientSession:
assert self.session is not None
return self.session
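# Loads the Stable Diffusion 3 pipeline once at startup and shares it across all requests.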
class TextToImagePipeline:
pipeline: StableDiffusion3Pipeline = None
device: str = None
def start(self):
if torch.cuda.is_available():
model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-large")
logger.info("Loading CUDA")
self.device = "cuda"
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
).to(device=self.device)
elif torch.backends.mps.is_available():
model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-medium")
logger.info("Loading MPS for Mac M Series")
self.device = "mps"
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
).to(device=self.device)
else:
raise Exception("No CUDA or MPS device available")
app = FastAPI()
service_url = os.getenv("SERVICE_URL", "http://localhost:8000")
image_dir = os.path.join(tempfile.gettempdir(), "images")
if not os.path.exists(image_dir):
os.makedirs(image_dir)
app.mount("/images", StaticFiles(directory=image_dir), name="images")
http_client = HttpClient()
shared_pipeline = TextToImagePipeline()
# Configure CORS settings
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods, e.g., GET, POST, OPTIONS, etc.
allow_headers=["*"], # Allows all headers
)
@app.on_event("startup")
def startup():
http_client.start()
shared_pipeline.start()
def save_image(image):
filename = "draw" + str(uuid.uuid4()).split("-")[0] + ".png"
image_path = os.path.join(image_dir, filename)
# write image to disk at image_path
logger.info(f"Saving image to {image_path}")
image.save(image_path)
return os.path.join(service_url, "images", filename)
@app.get("/")
@app.post("/")
@app.options("/")
async def base():
return "Welcome to Diffusers! Where you can use diffusion models to generate images"
@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
try:
loop = asyncio.get_event_loop()
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
generator = torch.Generator(device=shared_pipeline.device)
generator.manual_seed(random.randint(0, 10000000))
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
logger.info(f"output: {output}")
image_url = save_image(output.images[0])
return {"data": [{"url": image_url}]}
except Exception as e:
if isinstance(e, HTTPException):
raise e
elif hasattr(e, "message"):
raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)