1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/modules/api/train.py
Vladimir Mandic 7e447222a1 refactor api
2024-01-28 10:13:09 -05:00

91 lines
3.7 KiB
Python

from modules import shared, sd_hijack, devices
from modules.api import models
from modules.textual_inversion.preprocess import preprocess
def post_create_embedding(args: dict):
from modules.textual_inversion.textual_inversion import create_embedding
try:
shared.state.begin('api-embedding')
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
return models.CreateResponse(info = f"create embedding filename: {filename}")
except AssertionError as e:
shared.state.end()
return models.TrainResponse(info = f"create embedding error: {e}")
def post_create_hypernetwork(args: dict):
from modules.hypernetworks.hypernetwork import create_hypernetwork
try:
shared.state.begin('api-hypernetwork')
filename = create_hypernetwork(**args) # create empty embedding # pylint: disable=E1111
shared.state.end()
return models.CreateResponse(info = f"create hypernetwork filename: {filename}")
except AssertionError as e:
shared.state.end()
return models.TrainResponse(info = f"create hypernetwork error: {e}")
def post_preprocess(args: dict):
try:
shared.state.begin('api-preprocess')
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return models.PreprocessResponse(info = 'preprocess complete')
except KeyError as e:
shared.state.end()
return models.PreprocessResponse(info = f"preprocess error: invalid token: {e}")
except AssertionError as e:
shared.state.end()
return models.PreprocessResponse(info = f"preprocess error: {e}")
except FileNotFoundError as e:
shared.state.end()
return models.PreprocessResponse(info = f'preprocess error: {e}')
def post_train_embedding(args: dict):
from modules.textual_inversion.textual_inversion import train_embedding
try:
shared.state.begin('api-embedding')
apply_optimizations = False
error = None
filename = ''
if not apply_optimizations:
sd_hijack.undo_optimizations()
try:
_embedding, filename = train_embedding(**args) # can take a long time to complete
except Exception as e:
error = e
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info = f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg:
shared.state.end()
return models.TrainResponse(info = f"train embedding error: {msg}")
def post_train_hypernetwork(args: dict):
from modules.hypernetworks.hypernetwork import train_hypernetwork
try:
shared.state.begin('api-hypernetwork')
shared.loaded_hypernetworks = []
apply_optimizations = False
error = None
filename = ''
if not apply_optimizations:
sd_hijack.undo_optimizations()
try:
_hypernetwork, filename = train_hypernetwork(**args)
except Exception as e:
error = e
finally:
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError:
shared.state.end()
return models.TrainResponse(info=f"train embedding error: {error}")