# Mirror of https://github.com/vladmandic/sdnext.git
# (synced 2026-01-27 15:02:48 +03:00; 29 lines, 859 B, Python)
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Encoder(nn.Module):
    """Multi-scale feature extractor built on ``tf_efficientnet_b5_ap``.

    The backbone is created from the torch.hub repository vendored next to
    this file (``efficientnet_repo``) with ``pretrained=False``, so no weights
    are fetched here. Presumably the weights are restored later from a
    checkpoint by the caller (TODO: confirm). The global pooling layer and the
    classifier head are replaced with ``nn.Identity`` so the wrapped model
    only yields feature maps.
    """

    def __init__(self):
        super().__init__()

        basemodel_name = 'tf_efficientnet_b5_ap'
        repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
        # source='local' resolves the hub entrypoint from the on-disk repo
        # instead of downloading it from GitHub.
        basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local')

        # Remove last layer: neutralise pooling and the classification head.
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()

        self.original_model = basemodel

    def forward(self, x):
        """Run the backbone sequentially and collect every intermediate output.

        Args:
            x: Input passed to the first child module of the backbone.

        Returns:
            list: ``[x, out_1, out_2, ...]``. The list holds the input,
            followed by the output of each top-level child of the backbone in
            registration order. The ``blocks`` child is expanded one level, so
            each of its direct sub-modules contributes its own entry. Every
            module is applied to the previous entry of the list.
        """
        features = [x]
        for name, module in self.original_model.named_children():
            # Expand 'blocks' so each stage's output is exposed separately;
            # every other child is applied as a single step.
            stages = module.children() if name == 'blocks' else (module,)
            for stage in stages:
                features.append(stage(features[-1]))
        return features
|