Commit df5062d4 authored by Guillaume HEUSCH

[trainers] added the functionality to only keep models every N epochs, plus the latest one

parent 853568d1
@@ -201,7 +201,7 @@ class DRGANTrainer(object):
     return start_epoch

-  def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out', plot=False, save_sample=10, pose_random=True):
+  def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out', plot=False, save_sample=10, pose_random=True, keep_model=1):
     """
     Function that performs the training.
@@ -230,7 +230,13 @@ class DRGANTrainer(object):
     pose_random: boolean
       To assign random pose labels to generated images (necessary actually)
+    keep_model: int
+      To keep model every X epochs (and the last one)
     """
+
+    if not pose_random:
+      logger.warn("Generating same poses as in training examples")
+
     # create directories
     images_dir = output_dir + "/images"
     models_dir = output_dir + "/models"
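
For reference, a minimal usage sketch of the extended signature. The `trainer` object (an already-constructed DRGANTrainer) and the `dataloader` are assumed to exist and are not part of this commit; only the new keyword argument is illustrated.

# hypothetical call: `trainer` and `dataloader` are assumed to be built elsewhere.
# keep_model=10 keeps the checkpoints saved at epochs 0, 10, 20, ... plus the
# checkpoint of the most recent epoch; the default keep_model=1 keeps them all.
trainer.train(dataloader,
              n_epochs=50,
              learning_rate=0.0002,
              beta1=0.5,
              output_dir='out',
              keep_model=10)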
@@ -271,6 +277,7 @@ class DRGANTrainer(object):
     for epoch in range(start_epoch, n_epochs):
       for i, data in enumerate(dataloader, 0):
+
         start = time.time()

         # get the batch data, pose and id labels
@@ -425,7 +432,7 @@ class DRGANTrainer(object):
         # save losses
         filename = log_dir + '/losses_{}_{}.hdf5'.format(epoch, i)
-        f = bob.io.base.HDF5File(filename, 'w')
+        f = bob.io.base.HDF5File(filename, 'a')
         f.set('d_loss', discriminator_loss)
         f.set('g_loss', generator_loss)
         del f
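
Note on the mode change above: opening the per-iteration loss file with 'a' (append) instead of 'w' means an existing file with the same name is extended rather than truncated. A standalone sketch of the same pattern, with made-up values and filename:

import bob.io.base

# open (or create) the loss file in append mode; existing content is kept
f = bob.io.base.HDF5File('losses_0_0.hdf5', 'a')
f.set('d_loss', 0.42)   # dummy scalar values, for illustration only
f.set('g_loss', 1.37)
del f                   # deleting the handle closes and flushes the file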
@@ -461,6 +468,15 @@ class DRGANTrainer(object):
       torch.save(self.encoder.state_dict(), '%s/encoder_epoch_%d.pth' % (models_dir, epoch))
       torch.save(self.decoder.state_dict(), '%s/decoder_epoch_%d.pth' % (models_dir, epoch))
       torch.save(self.discriminator.state_dict(), '%s/discriminator_epoch_%d.pth' % (models_dir, epoch))
+
+      # remove models that we don't want to keep
+      import glob, os
+      model_files = glob.glob(models_dir + '/*.pth')
+      for model_file in model_files:
+        model_epoch = int(model_file.split('_')[-1].split('.')[0])
+        if ((model_epoch % keep_model) != 0) and (model_epoch != epoch):
+          os.remove(model_file)
+          logger.info("{} removed !".format(model_file))

   def check_batch_statistics(self, output_real, output_fake, output_generator,
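
To make the retention rule in the added block explicit: the checkpoint of epoch E survives if E is a multiple of keep_model or E is the current epoch; every other checkpoint is deleted. Below is a small standalone sketch (not part of the commit, filenames made up) that applies the same test without touching the filesystem.

# standalone illustration of the pruning rule used in the commit above
def files_to_remove(model_files, keep_model, current_epoch):
  """Return the checkpoint files that the loop in the commit would delete."""
  removed = []
  for model_file in model_files:
    # the epoch number is the last '_'-separated token before the extension
    model_epoch = int(model_file.split('_')[-1].split('.')[0])
    if ((model_epoch % keep_model) != 0) and (model_epoch != current_epoch):
      removed.append(model_file)
  return removed

files = ['out/models/encoder_epoch_{}.pth'.format(e) for e in range(13)]
print(files_to_remove(files, keep_model=5, current_epoch=12))
# epochs 1, 2, 3, 4, 6, 7, 8, 9 and 11 are removed; 0, 5, 10 and 12 are kept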