1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/scripts/lbm/base/base_model.py
Vladimir Mandic 2b9056179d add lbm background replace with relightining
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-07-04 15:33:16 -04:00

65 lines
2.0 KiB
Python

from typing import Any, Dict
import torch
import torch.nn as nn
from .model_config import ModelConfig
class BaseModel(nn.Module):
def __init__(self, config: ModelConfig):
nn.Module.__init__(self)
self.config = config
self.input_key = config.input_key
self.device = torch.device("cpu")
self.dtype = torch.float32
def on_fit_start(self, device: torch.device | None = None, *args, **kwargs):
"""Called when the training starts
Args:
device (Optional[torch.device], optional): The device to use. Usefull to set
relevant parameters on the model and embedder to the right device only
once at the start of the training. Defaults to None.
"""
if device is not None:
self.device = device
self.to(self.device)
def forward(self, batch: Dict[str, Any], *args, **kwargs):
raise NotImplementedError("forward method is not implemented")
def freeze(self):
"""Freeze the model"""
self.eval()
for param in self.parameters():
param.requires_grad = False
def to(self, *args, **kwargs):
device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
self = super().to(
device=device,
dtype=dtype,
non_blocking=non_blocking,
)
if device is not None:
self.device = device
if dtype is not None:
self.dtype = dtype
return self
def compute_metrics(self, batch: Dict[str, Any], *args, **kwargs):
"""Compute the metrics"""
return {}
def sample(self, batch: Dict[str, Any], *args, **kwargs):
"""Sample from the model"""
return {}
def log_samples(self, batch: Dict[str, Any], *args, **kwargs):
"""Log the samples"""
return None
def on_train_batch_end(self, batch: Dict[str, Any], *args, **kwargs):
"""Update the model an optimization is perforned on a batch."""
pass