Skip to content
Snippets Groups Projects

Light cnn

Merged Guillaume HEUSCH requested to merge lightCNN into master
1 file
+ 92
24
Compare changes
  • Side-by-side
  • Inline
@@ -13,6 +13,7 @@ logger = bob.core.log.setup("bob.learn.pytorch")
@@ -13,6 +13,7 @@ logger = bob.core.log.setup("bob.learn.pytorch")
import time
import time
import os
import os
 
import numpy
class CNNTrainer(object):
class CNNTrainer(object):
"""
"""
@@ -31,7 +32,7 @@ class CNNTrainer(object):
@@ -31,7 +32,7 @@ class CNNTrainer(object):
"""
"""
def __init__(self, network, batch_size=64, use_gpu=False, verbosity_level=2):
def __init__(self, network, batch_size=64, use_gpu=False, verbosity_level=2, num_classes=2):
""" Init function
""" Init function
Parameters
Parameters
@@ -44,9 +45,12 @@ class CNNTrainer(object):
@@ -44,9 +45,12 @@ class CNNTrainer(object):
If you would like to use the gpu
If you would like to use the gpu
verbosity_level: int
verbosity_level: int
The level of verbosity output to stdout
The level of verbosity output to stdout
 
num_classes: int
 
The number of classes
"""
"""
self.network = network
self.network = network
 
self.num_classes = num_classes
self.batch_size = batch_size
self.batch_size = batch_size
self.use_gpu = use_gpu
self.use_gpu = use_gpu
self.criterion = nn.CrossEntropyLoss()
self.criterion = nn.CrossEntropyLoss()
@@ -56,31 +60,88 @@ class CNNTrainer(object):
@@ -56,31 +60,88 @@ class CNNTrainer(object):
bob.core.log.set_verbosity_level(logger, verbosity_level)
bob.core.log.set_verbosity_level(logger, verbosity_level)
def load_and_initialize_model(self, model_filename):
def load_model(self, model_filename):
""" Loads and initialize a model
"""Loads an existing model
Parameters
Parameters
----------
----------
model_file: str
model_filename: str
The filename of the model to load
Returns
-------
start_epoch: int
The epoch to start with
start_iteration: int
The iteration to start with
losses: list(float)
The list of losses from previous training
"""
"""
try:
cp = torch.load(model_filename)
cp = torch.load(model_filename)
self.network.load_state_dict(cp['state_dict'])
#self.network.load_state_dict(cp['state_dict'])
start_epoch = cp['epoch']
except RuntimeError:
start_iter = cp['iteration']
# pre-trained model was probably saved using nn.DataParallel ...
losses = cp['loss']
cp = torch.load(model_filename, map_location='cpu')
 
if 'state_dict' in cp:
 
from collections import OrderedDict
 
new_state_dict = OrderedDict()
 
for k, v in cp['state_dict'].items():
 
name = k[7:]
 
new_state_dict[name] = v
 
cp['state_dict'] = new_state_dict
 
 
print(type(self.network))
 
 
###########################################################################################################
 
### for each defined architecture, get the output size in pre-trained model, and change it if necessary ###
 
 
# LightCNN9
 
if isinstance(self.network, bob.learn.pytorch.architectures.LightCNN.LightCNN9):
 
 
num_classes_pretrained = cp['state_dict']['fc2.weight'].shape[0]
 
 
if num_classes_pretrained == self.num_classes:
 
self.network.load_state_dict(cp['state_dict'])
 
else:
 
var = 1.0 / (cp['state_dict']['fc2.weight'].shape[0])
 
np_weights = numpy.random.normal(loc=0.0, scale=var, size=((self.num_classes+1), cp['state_dict']['fc2.weight'].shape[1]))
 
cp['state_dict']['fc2.weight'] = torch.from_numpy(np_weights)
 
cp['state_dict']['fc2.bias'] = torch.zeros(((self.num_classes+1),))
 
#self.network.load_state_dict(cp['state_dict'], strict=False)
 
self.network.load_state_dict(cp['state_dict'], strict=True)
 
 
# CNN8
 
if isinstance(self.network, bob.learn.pytorch.architectures.CNN8):
 
 
num_classes_pretrained = cp['state_dict']['classifier.weight'].shape[0]
 
if num_classes_pretrained == self.num_classes:
 
self.network.load_state_dict(cp['state_dict'])
 
else:
 
var = 1.0 / (cp['state_dict']['classifier.weight'].shape[0])
 
np_weights = numpy.random.normal(loc=0.0, scale=var, size=((self.num_classes+1), cp['state_dict']['classifier.weight'].shape[1]))
 
cp['state_dict']['classifier.weight'] = torch.from_numpy(np_weights)
 
cp['state_dict']['classifier.bias'] = torch.zeros(((self.num_classes+1),))
 
#self.network.load_state_dict(cp['state_dict'], strict=False)
 
self.network.load_state_dict(cp['state_dict'], strict=True)
 
 
# CASIANet
 
if isinstance(self.network, bob.learn.pytorch.architectures.CASIANet):
 
 
num_classes_pretrained = cp['state_dict']['classifier.weight'].shape[0]
 
if num_classes_pretrained == self.num_classes:
 
self.network.load_state_dict(cp['state_dict'])
 
else:
 
var = 1.0 / (cp['state_dict']['classifier.weight'].shape[0])
 
np_weights = numpy.random.normal(loc=0.0, scale=var, size=((self.num_classes+1), cp['state_dict']['classifier.weight'].shape[1]))
 
cp['state_dict']['classifier.weight'] = torch.from_numpy(np_weights)
 
cp['state_dict']['classifier.bias'] = torch.zeros(((self.num_classes+1),))
 
#self.network.load_state_dict(cp['state_dict'], strict=False)
 
self.network.load_state_dict(cp['state_dict'], strict=True)
 
 
###########################################################################################################
 
 
start_epoch = 0
 
start_iter = 0
 
losses = []
 
if 'epoch' in cp.keys():
 
start_epoch = cp['epoch']
 
if 'iteration' in cp.keys():
 
start_iter = cp['iteration']
 
if 'losses' in cp.keys():
 
losses = cp['epoch']
 
return start_epoch, start_iter, losses
return start_epoch, start_iter, losses
@@ -133,8 +194,15 @@ class CNNTrainer(object):
@@ -133,8 +194,15 @@ class CNNTrainer(object):
# if model exists, load it
# if model exists, load it
if model is not None:
if model is not None:
start_epoch, start_iter, losses = self.load_model(model)
logger.info('Starting training at epoch {}, iteration {} - last loss value is {}'.format(start_epoch, start_iter, losses[-1]))
start_epoch, start_iter, losses = self.load_and_initialize_model(model)
 
if start_epoch != 0:
 
logger.info('Previous network was trained up to epoch {}, iteration {}'.format(start_epoch, start_iter, losses[-1]))
 
if losses:
 
logger.info('Last loss = {}'.format(losses[-1]))
 
else:
 
logger.info('Starting training / fine-tuning from pre-trained model')
 
else:
else:
start_epoch = 0
start_epoch = 0
start_iter = 0
start_iter = 0
Loading