Skip to content
Snippets Groups Projects
Commit 49195707 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainer] corrected stuff in the CNN trainer (saving models)

parent 1f61fbf1
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ import bob.core ...@@ -13,7 +13,7 @@ import bob.core
logger = bob.core.log.setup("bob.learn.pytorch") logger = bob.core.log.setup("bob.learn.pytorch")
import time import time
import os
class CNNTrainer(object): class CNNTrainer(object):
""" """
...@@ -75,7 +75,7 @@ class CNNTrainer(object): ...@@ -75,7 +75,7 @@ class CNNTrainer(object):
return start_epoch, start_iter, losses return start_epoch, start_iter, losses
def save_model(output_dir, epoch=0, iteration=0, losses=None): def save_model(self, output_dir, epoch=0, iteration=0, losses=None):
""" """
Save the trained network Save the trained network
...@@ -96,7 +96,7 @@ class CNNTrainer(object): ...@@ -96,7 +96,7 @@ class CNNTrainer(object):
saved_filename = 'model_{}_{}.pth'.format(epoch, iteration) saved_filename = 'model_{}_{}.pth'.format(epoch, iteration)
saved_path = os.path.join(output_dir, saved_filename) saved_path = os.path.join(output_dir, saved_filename)
logger.info('Saving model to {}'.format(save_path)) logger.info('Saving model to {}'.format(saved_path))
cp = {'epoch': epoch, cp = {'epoch': epoch,
'iteration': iteration, 'iteration': iteration,
...@@ -131,16 +131,19 @@ class CNNTrainer(object): ...@@ -131,16 +131,19 @@ class CNNTrainer(object):
# if model exists, load it # if model exists, load it
if model is not None: if model is not None:
start_epoch, start_iter, losses = self.load_model(model) start_epoch, start_iter, losses = self.load_model(model)
logger.info('Starting training at epoch {}, iteration {} - last loss value is {}'.format(start_epoch, start_iter, losses[-1]))
else: else:
start_epoch = 0 start_epoch = 0
start_iter = 0 start_iter = 0
losses = [] losses = []
logger.info('Starting training from scratch')
# setup optimizer # setup optimizer
optimizer = optim.SGD(self.network.parameters(), learning_rate, momentum = 0.9, weight_decay = 0.0005) optimizer = optim.SGD(self.network.parameters(), learning_rate, momentum = 0.9, weight_decay = 0.0005)
# let's go # let's go
for epoch in range((start_epoch, n_epochs): for epoch in range(start_epoch, n_epochs):
for i, data in enumerate(dataloader, 0): for i, data in enumerate(dataloader, 0):
if i >= start_iter: if i >= start_iter:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment