mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-29 05:02:09 +03:00
91 lines
3.7 KiB
Python
91 lines
3.7 KiB
Python
from modules import shared, sd_hijack, devices
|
|
from modules.api import models
|
|
from modules.textual_inversion.preprocess import preprocess
|
|
|
|
|
|
def post_create_embedding(args: dict):
|
|
from modules.textual_inversion.textual_inversion import create_embedding
|
|
try:
|
|
shared.state.begin('api-embedding')
|
|
filename = create_embedding(**args) # create empty embedding
|
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
|
shared.state.end()
|
|
return models.CreateResponse(info = f"create embedding filename: {filename}")
|
|
except AssertionError as e:
|
|
shared.state.end()
|
|
return models.TrainResponse(info = f"create embedding error: {e}")
|
|
|
|
def post_create_hypernetwork(args: dict):
|
|
from modules.hypernetworks.hypernetwork import create_hypernetwork
|
|
try:
|
|
shared.state.begin('api-hypernetwork')
|
|
filename = create_hypernetwork(**args) # create empty embedding # pylint: disable=E1111
|
|
shared.state.end()
|
|
return models.CreateResponse(info = f"create hypernetwork filename: {filename}")
|
|
except AssertionError as e:
|
|
shared.state.end()
|
|
return models.TrainResponse(info = f"create hypernetwork error: {e}")
|
|
|
|
def post_preprocess(args: dict):
|
|
try:
|
|
shared.state.begin('api-preprocess')
|
|
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
|
shared.state.end()
|
|
return models.PreprocessResponse(info = 'preprocess complete')
|
|
except KeyError as e:
|
|
shared.state.end()
|
|
return models.PreprocessResponse(info = f"preprocess error: invalid token: {e}")
|
|
except AssertionError as e:
|
|
shared.state.end()
|
|
return models.PreprocessResponse(info = f"preprocess error: {e}")
|
|
except FileNotFoundError as e:
|
|
shared.state.end()
|
|
return models.PreprocessResponse(info = f'preprocess error: {e}')
|
|
|
|
def post_train_embedding(args: dict):
|
|
from modules.textual_inversion.textual_inversion import train_embedding
|
|
try:
|
|
shared.state.begin('api-embedding')
|
|
apply_optimizations = False
|
|
error = None
|
|
filename = ''
|
|
if not apply_optimizations:
|
|
sd_hijack.undo_optimizations()
|
|
try:
|
|
_embedding, filename = train_embedding(**args) # can take a long time to complete
|
|
except Exception as e:
|
|
error = e
|
|
finally:
|
|
if not apply_optimizations:
|
|
sd_hijack.apply_optimizations()
|
|
shared.state.end()
|
|
return models.TrainResponse(info = f"train embedding complete: filename: {filename} error: {error}")
|
|
except AssertionError as msg:
|
|
shared.state.end()
|
|
return models.TrainResponse(info = f"train embedding error: {msg}")
|
|
|
|
def post_train_hypernetwork(args: dict):
|
|
from modules.hypernetworks.hypernetwork import train_hypernetwork
|
|
try:
|
|
shared.state.begin('api-hypernetwork')
|
|
shared.loaded_hypernetworks = []
|
|
apply_optimizations = False
|
|
error = None
|
|
filename = ''
|
|
if not apply_optimizations:
|
|
sd_hijack.undo_optimizations()
|
|
try:
|
|
_hypernetwork, filename = train_hypernetwork(**args)
|
|
except Exception as e:
|
|
error = e
|
|
finally:
|
|
shared.sd_model.cond_stage_model.to(devices.device)
|
|
shared.sd_model.first_stage_model.to(devices.device)
|
|
if not apply_optimizations:
|
|
sd_hijack.apply_optimizations()
|
|
shared.state.end()
|
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
|
except AssertionError:
|
|
shared.state.end()
|
|
return models.TrainResponse(info=f"train embedding error: {error}")
|