Skip to content
Snippets Groups Projects
Commit 23112c39 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainer] added code to save various stats

parent 51302e73
No related branches found
No related tags found
No related merge requests found
......@@ -93,7 +93,7 @@ class DRGANTrainer(object):
self.criterion_id.cuda()
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out', plot=False, save_sample=100, pose_random=True):
def train(self, dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out', plot=False, save_sample=10, pose_random=True):
"""
Function that performs the training.
......@@ -115,8 +115,26 @@ class DRGANTrainer(object):
The directory where you would like to output images and models
plot: boolean
If you want to plot some images during the training process (debug)
If you want to plot some images during the training process (debug)
save_sample: int
To save sample every X iterations
pose_random: boolean
To assign random pose labels to generated images (necessary actually)
"""
# create directories
images_dir = output_dir + "/images"
models_dir = output_dir + "/models"
log_dir = output_dir + "/logs"
bob.io.base.create_directories_safe(images_dir)
bob.io.base.create_directories_safe(models_dir)
bob.io.base.create_directories_safe(log_dir)
# be sure to save samples at each epoch at least
if save_sample < len(dataloader):
save_sample = len(dataloader) - 1
# labels for real/fake
real_label = 1
fake_label = 0
......@@ -126,7 +144,8 @@ class DRGANTrainer(object):
optimizerD = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(generator_params, lr=learning_rate, betas=(beta1, 0.999))
# be sure to have a fixed frontal image to sample from
# ==================================================================================================
# TODO: fix this - Guillaume HEUSCH, 04-12-2017
pose = 0
counter = 0
while pose != 6:
......@@ -159,7 +178,13 @@ class DRGANTrainer(object):
for k in range(self.conditional_dim):
fixed_one_hot[k, k] = 1
fixed_one_hot = Variable(fixed_one_hot)
# ==================================================================================================
# statistics
discriminator_loss = []
generator_loss = []
# ================
# === LET'S GO ===
# ================
......@@ -282,9 +307,13 @@ class DRGANTrainer(object):
optimizerG.step()
end = time.time()
logger.info("[{}/{}][{}/{}] => Loss D = {} -- Loss G = {} (time spent: {})".format(epoch, n_epochs, i, len(dataloader), errD.data[0], errG.data[0], (end-start)))
discriminator_loss.append(errD.data[0])
generator_loss.append(errG.data[0])
self.check_batch_statistics(output_real, output_fake, output_generated, ids, poses, random_poses, batch_size)
self.check_batch_statistics(output_real, output_fake, output_generated, ids, poses, random_poses, batch_size, log_dir)
# =====================
# SAVE GENERATED IMAGES
......@@ -306,7 +335,7 @@ class DRGANTrainer(object):
pose_example = random_poses[index]
# save hdf5 data:
filename = output_dir + '/generated_data_{}_{}.hdf5'.format(epoch, i)
filename = images_dir + '/generated_data_{}_{}.hdf5'.format(epoch, i)
f = bob.io.base.HDF5File(filename, 'w')
f.set('id', id_example)
f.set('real_example', real_example)
......@@ -315,6 +344,12 @@ class DRGANTrainer(object):
f.set('target_pose', pose_example)
del f
# save losses
filename = log_dir + '/losses_{}_{}.hdf5'.format(epoch, i)
f = bob.io.base.HDF5File(filename, 'w')
f.set('d_loss', discriminator_loss)
f.set('g_loss', generator_loss)
# save generated images at every epoch
# TODO: model moved to CPU and back and I don't really know why (expected CPU tensor error)
......@@ -340,7 +375,7 @@ class DRGANTrainer(object):
def check_batch_statistics(self, output_real, output_fake, output_generator,
label_id, label_pose, label_random_pose, batch_size):
label_id, label_pose, label_random_pose, batch_size, log_dir):
"""
Compute some performance stats on the current mini-batch
......@@ -366,6 +401,9 @@ class DRGANTrainer(object):
batch_size: int
The size of the current batch
log_dir: path
The directory where to store the logs
"""
# --- REAL ---
......@@ -431,7 +469,27 @@ class DRGANTrainer(object):
generator_pose_correct += 1
if prob_real[0] < 0.5:
generator_gan_correct +=1
fd_real = bob.io.base.HDF5File(log_dir + '/discriminator_real_stats.hdf5', 'a')
fd_fake = bob.io.base.HDF5File(log_dir + '/discriminator_fake_stats.hdf5', 'a')
fg = bob.io.base.HDF5File(log_dir + '/generator_stats.hdf5', 'a')
fd_real.append('r_id_accuracy', real_id_correct/float(batch_size))
fd_real.append('r_pose_accuracy', real_pose_correct/float(batch_size))
fd_real.append('r_real_accuracy', real_gan_correct/float(batch_size))
fd_fake.append('f_id_accuracy', fake_id_correct/float(batch_size))
fd_fake.append('f_pose_accuracy', fake_pose_correct/float(batch_size))
fd_fake.append('f_fake_accuracy', fake_gan_correct/float(batch_size))
fg.append('g_id_accuracy', fake_id_correct/float(batch_size))
fg.append('g_pose_accuracy', fake_pose_correct/float(batch_size))
fg.append('g_fake_accuracy', real_gan_correct/float(batch_size))
del fd_real
del fd_fake
del fg
logger.debug("[REAL] accuracy on ID = {} ({}/{})".format(real_id_correct/float(batch_size), real_id_correct, batch_size))
logger.debug("[REAL] accuracy on pose = {} ({}/{})".format(real_pose_correct/float(batch_size), real_pose_correct, batch_size))
logger.debug("[REAL] accuracy on real/fake = {} ({}/{})".format(real_gan_correct/float(batch_size), real_gan_correct, batch_size))
......
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment