Commit a26ab66c authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[trainers] fixed stuff on CNNTrainer, to be able to load model saved on both...

[trainers] fixed stuff on CNNTrainer, to be able to load model saved on both CPU/GPU on both CPU/GPU
parent 6a7a4fed
Pipeline #26438 passed with stage
in 10 minutes and 49 seconds
@@ -70,35 +70,36 @@ class CNNTrainer(object):
"""
try:
cp = torch.load(model_filename)
#self.network.load_state_dict(cp['state_dict'])
except RuntimeError:
# pre-trained model was probably saved using nn.DataParallel ...
cp = torch.load(model_filename, map_location='cpu')
if 'state_dict' in cp:
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in cp['state_dict'].items():
name = k[7:]
new_state_dict[name] = v
cp['state_dict'] = new_state_dict
print(type(self.network))
if 'state_dict' in cp:
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in cp['state_dict'].items():
name = k[7:]
new_state_dict[name] = v
cp['state_dict'] = new_state_dict
###########################################################################################################
### for each defined architecture, get the output size in pre-trained model, and change it if necessary ###
# LightCNN9
if isinstance(self.network, bob.learn.pytorch.architectures.LightCNN.LightCNN9):
num_classes_pretrained = cp['state_dict']['fc2.weight'].shape[0]
last_layer_weight = 'fc2.weight'
last_layer_bias = 'fc2.bias'
num_classes_pretrained = cp['state_dict'][last_layer_weight].shape[0]
if num_classes_pretrained == self.num_classes:
self.network.load_state_dict(cp['state_dict'])
else:
var = 1.0 / (cp['state_dict']['fc2.weight'].shape[0])
np_weights = numpy.random.normal(loc=0.0, scale=var, size=((self.num_classes+1), cp['state_dict']['fc2.weight'].shape[1]))
cp['state_dict']['fc2.weight'] = torch.from_numpy(np_weights)
cp['state_dict']['fc2.bias'] = torch.zeros(((self.num_classes+1),))
var = 1.0 / (cp['state_dict'][last_layer_weight].shape[0])
np_weights = numpy.random.normal(loc=0.0, scale=var, size=((self.num_classes+1), cp['state_dict'][last_layer_weight].shape[1]))
cp['state_dict'][last_layer_weight] = torch.from_numpy(np_weights)
cp['state_dict'][last_layer_bias] = torch.zeros(((self.num_classes+1),))
#self.network.load_state_dict(cp['state_dict'], strict=False)
self.network.load_state_dict(cp['state_dict'], strict=True)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment