# mirror of https://github.com/vladmandic/sdnext.git
# synced 2026-01-27 15:02:48 +03:00
from typing import Optional, List
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from fastapi.exceptions import HTTPException
from modules import shared


class ReqFramepack(BaseModel):
    """Request payload for the FramePack video-generation endpoint.

    Fields map one-to-one onto keyword arguments of `run_framepack`;
    images travel as base64-encoded strings and are decoded server-side
    in `framepack_post`.
    """
    # model selection and prompting
    variant: Optional[str] = Field(default=None, title="Model variant", description="Model variant to use")
    prompt: Optional[str] = Field(default=None, title="Prompt", description="Prompt for the model")
    # conditioning images (base64-encoded) and their weights
    init_image: Optional[str] = Field(default=None, title="Initial image", description="Base64 encoded initial image")
    end_image: Optional[str] = Field(default=None, title="End image", description="Base64 encoded end image")
    start_weight: Optional[float] = Field(default=1.0, title="Start weight", description="Weight of the initial image")
    end_weight: Optional[float] = Field(default=1.0, title="End weight", description="Weight of the end image")
    vision_weight: Optional[float] = Field(default=1.0, title="Vision weight", description="Weight of the vision model")
    # prompt refinement
    system_prompt: Optional[str] = Field(default=None, title="System prompt", description="System prompt for the model")
    optimized_prompt: Optional[bool] = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model")
    section_prompt: Optional[str] = Field(default=None, title="Section prompt", description="Prompt for each section")
    negative_prompt: Optional[str] = Field(default=None, title="Negative prompt", description="Negative prompt for the model")
    styles: Optional[List[str]] = Field(default=None, title="Styles", description="Styles for the model")
    # generation parameters
    seed: Optional[int] = Field(default=None, title="Seed", description="Seed for the model")
    resolution: Optional[int] = Field(default=640, title="Resolution", description="Resolution of the image")
    duration: Optional[float] = Field(default=4, title="Duration", description="Duration of the video in seconds")
    latent_ws: Optional[int] = Field(default=9, title="Latent window size", description="Size of the latent window")
    steps: Optional[int] = Field(default=25, title="Video steps", description="Number of steps for the video generation")
    cfg_scale: Optional[float] = Field(default=1.0, title="CFG scale", description="CFG scale for the model")
    cfg_distilled: Optional[float] = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model")
    cfg_rescale: Optional[float] = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model")
    shift: Optional[float] = Field(default=0, title="Sampler shift", description="Shift for the sampler")
    use_teacache: Optional[bool] = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model")
    use_cfgzero: Optional[bool] = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model")
    # video encoding / output options
    mp4_fps: Optional[int] = Field(default=30, title="FPS", description="Frames per second for the video")
    mp4_codec: Optional[str] = Field(default="libx264", title="Codec", description="Codec for the video")
    mp4_sf: Optional[bool] = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video")
    mp4_video: Optional[bool] = Field(default=True, title="Save Video", description="Save video")
    mp4_frames: Optional[bool] = Field(default=False, title="Save Frames", description="Save frames for the video")
    mp4_opt: Optional[str] = Field(default="crf:16", title="Options", description="Options for the video codec")
    mp4_ext: Optional[str] = Field(default="mp4", title="Format", description="Format for the video")
    mp4_interpolate: Optional[int] = Field(default=0, title="Interpolation", description="Interpolation for the video")
    # backend selection
    attention: Optional[str] = Field(default="Default", title="Attention", description="Attention type for the model")
    vae_type: Optional[str] = Field(default="Local", title="VAE", description="VAE type for the model")
    # optional VLM-based prompt enhancement
    vlm_enhance: Optional[bool] = Field(default=False, title="VLM enhance", description="Enable VLM enhance")
    vlm_model: Optional[str] = Field(default=None, title="VLM model", description="VLM model to use")
    vlm_system_prompt: Optional[str] = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model")


class ResFramepack(BaseModel):
    """Response payload for the FramePack endpoint."""
    id: str = Field(title="TaskID", description="Task ID")
    # NOTE: titles/descriptions previously copy-pasted from `id` ("TaskID"/"Task ID")
    filename: str = Field(title="Filename", description="Output video filename")
    message: str = Field(title="Message", description="Status message from the generator")


def framepack_post(request: ReqFramepack):
    """Run a FramePack video-generation task from an API request.

    Decodes the base64 images carried by the request, forwards the remaining
    fields to `run_framepack`, and drains its generator while keeping the
    latest reported filename and status message.

    Args:
        request: validated FramePack request payload.
    Returns:
        ResFramepack with the task id, last output filename and last message.
    Raises:
        HTTPException: 500 when a supplied image cannot be decoded.
    """
    import numpy as np
    from modules.api import helpers
    from framepack_wrappers import run_framepack
    task_id = shared.state.get_id()

    def _decode_image(encoded, label):
        # Base64 -> numpy decode; missing/empty input passes through as None.
        if not encoded:
            return None
        try:
            return np.array(helpers.decode_base64_to_image(encoded))
        except Exception as e:
            shared.log.error(f"API FramePack: id={task_id} cannot decode {label} image: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    init_image = _decode_image(request.init_image, 'init')
    end_image = _decode_image(request.end_image, 'end')
    # drop the potentially large base64 payloads before logging the request
    del request.init_image
    del request.end_image
    # guard with `is not None`: `.shape` would crash on None, and truth-testing
    # a multi-element numpy array raises ValueError
    shared.log.trace(f"API FramePack: id={task_id} init={init_image.shape if init_image is not None else None} end={end_image.shape if end_image is not None else None} {request}")

    generator = run_framepack(
        _ui_state=None,
        task_id=f'task({task_id})',
        variant=request.variant,
        init_image=init_image,
        end_image=end_image,
        start_weight=request.start_weight,
        end_weight=request.end_weight,
        vision_weight=request.vision_weight,
        prompt=request.prompt,
        system_prompt=request.system_prompt,
        optimized_prompt=request.optimized_prompt,
        section_prompt=request.section_prompt,
        negative_prompt=request.negative_prompt,
        styles=request.styles,
        seed=request.seed,
        resolution=request.resolution,
        duration=request.duration,
        latent_ws=request.latent_ws,
        steps=request.steps,
        cfg_scale=request.cfg_scale,
        cfg_distilled=request.cfg_distilled,
        cfg_rescale=request.cfg_rescale,
        shift=request.shift,
        use_teacache=request.use_teacache,
        use_cfgzero=request.use_cfgzero,
        use_preview=False,  # previews are UI-only; the API returns the final file
        mp4_fps=request.mp4_fps,
        mp4_codec=request.mp4_codec,
        mp4_sf=request.mp4_sf,
        mp4_video=request.mp4_video,
        mp4_frames=request.mp4_frames,
        mp4_opt=request.mp4_opt,
        mp4_ext=request.mp4_ext,
        mp4_interpolate=request.mp4_interpolate,
        attention=request.attention,
        vae_type=request.vae_type,
        vlm_enhance=request.vlm_enhance,
        vlm_model=request.vlm_model,
        vlm_system_prompt=request.vlm_system_prompt,
    )
    response = ResFramepack(id=task_id, filename='', message='')
    # generator yields UI tuples (filename, preview, message); keep the latest values
    for message in generator:
        if isinstance(message, tuple) and len(message) == 3:
            if isinstance(message[0], str):
                response.filename = message[0]
            if isinstance(message[2], str):
                response.message = message[2]
    return response


def create_api(_fastapi, _gradioapp):
    """Register the FramePack generation endpoint on the shared API router.

    The fastapi and gradio app arguments are part of the plugin hook
    signature and are intentionally unused here.
    """
    endpoint = "/sdapi/v1/framepack"
    shared.api.add_api_route(endpoint, framepack_post, methods=["POST"], response_model=ResFramepack)