mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* Basic implementation of request scheduling * Basic editing in SD and Flux Pipelines * Small Fix * Fix * Update for more pipelines * Add examples/server-async * Add examples/server-async * Updated RequestScopedPipeline to handle a single tokenizer lock to avoid race conditions * Fix * Fix _TokenizerLockWrapper * Fix _TokenizerLockWrapper * Delete _TokenizerLockWrapper * Fix tokenizer * Update examples/server-async * Fix server-async * Optimizations in examples/server-async * We keep the implementation simple in examples/server-async * Update examples/server-async/README.md * Update examples/server-async/README.md for changes to tokenizer locks and backward-compatible retrieve_timesteps * The changes to the diffusers core have been undone and all logic is being moved to exmaples/server-async * Update examples/server-async/utils/* * Fix BaseAsyncScheduler * Rollback in the core of the diffusers * Update examples/server-async/README.md * Complete rollback of diffusers core files * Simple implementation of an asynchronous server compatible with SD3-3.5 and Flux Pipelines * Update examples/server-async/README.md * Fixed import errors in 'examples/server-async/serverasync.py' * Flux Pipeline Discard * Update examples/server-async/README.md * Apply style fixes --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
92 lines
3.1 KiB
Python
92 lines
3.1 KiB
Python
import logging
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from typing import List
|
|
|
|
import torch
|
|
from pydantic import BaseModel
|
|
|
|
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
|
|
|
|
|
# Module-level logger named after this module, per the stdlib logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TextToImageInput(BaseModel):
    """Request-body schema for a text-to-image generation endpoint."""

    # Identifier of the model to run (matched against PresetModels lists elsewhere in this file).
    model: str
    # Text prompt describing the desired image.
    prompt: str
    # Optional output size string — presumably "WIDTHxHEIGHT" (OpenAI-style API); TODO confirm format.
    size: str | None = None
    # Optional number of images to generate.
    n: int | None = None
|
|
|
|
|
|
@dataclass
class PresetModels:
    """Known Hugging Face model repo ids, grouped by Stable Diffusion version family."""

    # Stable Diffusion 3 checkpoints.
    SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
    # Stable Diffusion 3.5 checkpoints.
    SD3_5: List[str] = field(
        default_factory=lambda: [
            "stabilityai/stable-diffusion-3.5-large",
            "stabilityai/stable-diffusion-3.5-large-turbo",
            "stabilityai/stable-diffusion-3.5-medium",
        ]
    )
|
|
|
|
|
|
class TextToImagePipelineSD3:
    """Lazily-loaded Stable Diffusion 3/3.5 text-to-image pipeline wrapper.

    Construction is cheap; the heavyweight model is only fetched and moved to
    a device when :meth:`start` is called.
    """

    def __init__(self, model_path: str | None = None):
        # An explicit argument wins; otherwise fall back to the MODEL_PATH env var.
        self.model_path = model_path or os.getenv("MODEL_PATH")
        self.pipeline: StableDiffusion3Pipeline | None = None
        self.device: str | None = None

    def start(self):
        """Load the pipeline onto the best available accelerator.

        Prefers CUDA (float16), then Apple MPS (bfloat16).

        Raises:
            Exception: if neither CUDA nor MPS is available.
        """
        if torch.cuda.is_available():
            logger.info("Loading CUDA")
            self.device = "cuda"
            self._load_pipeline(
                fallback_model="stabilityai/stable-diffusion-3.5-large",
                dtype=torch.float16,
            )
        elif torch.backends.mps.is_available():
            logger.info("Loading MPS for Mac M Series")
            self.device = "mps"
            self._load_pipeline(
                fallback_model="stabilityai/stable-diffusion-3.5-medium",
                dtype=torch.bfloat16,
            )
        else:
            raise Exception("No CUDA or MPS device available")

    def _load_pipeline(self, fallback_model: str, dtype) -> None:
        """Instantiate the pipeline on ``self.device`` with the given dtype."""
        chosen_path = self.model_path or fallback_model
        self.pipeline = StableDiffusion3Pipeline.from_pretrained(
            chosen_path,
            torch_dtype=dtype,
        ).to(device=self.device)
|
|
|
|
|
|
class ModelPipelineInitializer:
    """Resolve a model name to a (not-yet-loaded) pipeline wrapper for a task.

    Args:
        model: Model repo id; must appear in one of the PresetModels lists.
        type_models: Task selector; currently only "t2im" (text-to-image)
            is supported.
    """

    def __init__(self, model: str = "", type_models: str = "t2im"):
        self.model = model
        self.type_models = type_models
        self.pipeline = None
        # NOTE(review): this assumes MPS whenever CUDA is absent, which is
        # wrong on CPU-only hosts — consider also checking
        # torch.backends.mps.is_available(). Left unchanged to preserve behavior.
        self.device = "cuda" if torch.cuda.is_available() else "mps"
        self.model_type = None

    def initialize_pipeline(self):
        """Create the pipeline wrapper for ``self.model`` (weights not loaded yet).

        Returns:
            The pipeline wrapper instance; call its ``start()`` to load weights.

        Raises:
            ValueError: if no model name was given, the model is not supported
                for the requested task, or ``type_models`` is unknown.
        """
        if not self.model:
            raise ValueError("Model name not provided")

        # Check if model exists in PresetModels
        preset_models = PresetModels()

        # Determine which model type we're dealing with
        if self.model in preset_models.SD3:
            self.model_type = "SD3"
        elif self.model in preset_models.SD3_5:
            self.model_type = "SD3_5"

        # Create appropriate pipeline based on model type and type_models
        if self.type_models == "t2im":
            if self.model_type in ["SD3", "SD3_5"]:
                self.pipeline = TextToImagePipelineSD3(self.model)
            else:
                raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
        else:
            # Fix: previously only type_models == "t2v" raised here, so any other
            # unknown task string fell through and silently returned None.
            raise ValueError(f"Unsupported type_models: {self.type_models}")

        return self.pipeline
|