Skip to content
Snippets Groups Projects
Commit a528e833 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[extractor] added the possibility to have an untrained model (for test purposes)

parent 28f9ae2d
No related branches found
No related tags found
No related merge requests found
import numpy
import sys
import torch
import torch.nn as nn
......@@ -7,6 +8,10 @@ from torch.autograd import Variable
import torchvision.transforms as transforms
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
CNN8_CONFIG = [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M']
def make_conv_layers(cfg, input_c = 3):
......@@ -43,14 +48,17 @@ from bob.bio.base.extractor import Extractor
class CNN8Extractor(Extractor):
def __init__(self, model_file, num_classes=10575):
def __init__(self, model_file=None, num_classes=10575):
Extractor.__init__(self, skip_extractor_training=True)
# model
self.network = CNN8(num_classes)
cp = torch.load(model_file)
if 'state_dict' in cp:
if model_file is None:
logger.warning("No model file provided, building network from scratch !")
else:
cp = torch.load(model_file)
if 'state_dict' in cp:
self.network.load_state_dict(cp['state_dict'])
self.network.eval()
......
......@@ -7,6 +7,9 @@ from torch.autograd import Variable
import torchvision.transforms as transforms
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
CASIA_CONFIG = [32, 64, 'M', 64, 128, 'M', 96, 192, 'M', 128, 256, 'M', 160, 320]
def make_conv_layers(cfg, input_c = 3):
......@@ -44,14 +47,17 @@ from bob.bio.base.extractor import Extractor
class CasiaNetExtractor(Extractor):
def __init__(self, model_file, num_classes=10575):
def __init__(self, model_file=None, num_classes=10575):
Extractor.__init__(self, skip_extractor_training=True)
# model
self.network = CASIA_NET(num_classes)
cp = torch.load(model_file)
if 'state_dict' in cp:
if model_file is None:
logger.warning("No model file provided, building network from scratch !")
else:
cp = torch.load(model_file)
if 'state_dict' in cp:
self.network.load_state_dict(cp['state_dict'])
self.network.eval()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment