1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
Vladimir Mandic 6e04bac7de merge control
2023-12-20 15:27:13 -05:00

20 lines
552 B
Python

import torch.nn as nn
from .submodules.encoder import Encoder
from .submodules.decoder import Decoder
class NNET(nn.Module):
    """Encoder/decoder network: the encoder output is passed to the decoder.

    The two sub-networks are exposed as separate parameter groups so a caller
    can assign them different learning rates.
    """

    def __init__(self, args):
        """Create the sub-networks.

        Args:
            args: configuration object forwarded unchanged to ``Decoder``;
                the encoder is constructed without arguments.
        """
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(args)

    def get_1x_lr_params(self):
        """Return the encoder parameters (the lr/10 learning-rate group)."""
        encoder = self.encoder
        return encoder.parameters()

    def get_10x_lr_params(self):
        """Return the decoder parameters (the full lr learning-rate group)."""
        decoder = self.decoder
        return decoder.parameters()

    def forward(self, img, **kwargs):
        """Encode ``img`` and decode the result.

        Args:
            img: input passed to the encoder as its only argument.
            **kwargs: extra keyword arguments handed to the decoder only.

        Returns:
            Whatever the decoder returns for the encoded input.
        """
        features = self.encoder(img)
        return self.decoder(features, **kwargs)