Skip to content
Snippets Groups Projects
Commit 5db8116f authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[scripts] add script to display training stats (loss, accuracy, ...)

parent 63a72e48
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python
# encoding: utf-8
""" Read data saved during the training of a DR-GAN
Usage:
%(prog)s [--logdir=<path>] [--verbose ...]
Options:
-h, --help Show this screen.
-V, --version Show version.
-d, --logdir=<path> The dir where the training data reside
-v, --verbose Increase the verbosity (may appear multiple times).
Example:
To read and display the training data:
$ %(prog)s --logdir ./drgan/logs
See '%(prog)s --help' for more information.
"""
import os, sys
import pkg_resources
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
from docopt import docopt
version = pkg_resources.require('bob.learn.pytorch')[0].version
import numpy
import bob.io.base
from matplotlib import pyplot
def main(user_input=None):
# Parse the command-line arguments
if user_input is not None:
arguments = user_input
else:
arguments = sys.argv[1:]
prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,)
# verbosity
verbosity_level = args['--verbose']
bob.core.log.set_verbosity_level(logger, verbosity_level)
# get the arguments
logdir = args['--logdir']
# === LOSSES ===
# get the last losses file
import glob
losses_files = glob.glob(logdir + '/losses_*') # * means all if need specific format then *.csv
loss_filename = max(losses_files, key=os.path.getctime)
print loss_filename
#fl = bob.io.base.HDF5File(loss_filename)
#d_loss = fl.read('d_loss')
#g_loss = fl.read('g_loss')
#pyplot.title("Losses")
#pyplot.xlabel("# of iterations")
#pyplot.plot(d_loss, 'b', label="discriminator")
#pyplot.plot(g_loss, 'r', label="generator")
#pyplot.legend()
#pyplot.show()
#del fl
# === DISCRIMINATOR ===
fdr = bob.io.base.HDF5File(logdir + '/discriminator_real_stats.hdf5')
real_id_acc = fdr.read('r_id_accuracy')
real_pose_acc = fdr.read('r_pose_accuracy')
real_gan_acc = fdr.read('r_real_accuracy')
fdf = bob.io.base.HDF5File(logdir + '/discriminator_fake_stats.hdf5')
fake_id_acc = fdf.read('f_id_accuracy')
fake_pose_acc = fdf.read('f_pose_accuracy')
fake_gan_acc = fdf.read('f_fake_accuracy')
f, axarr = pyplot.subplots(3, sharex=True)
f.suptitle("Discriminator stats")
axarr[0].set_title("Identity")
axarr[0].plot(real_id_acc, label="real")
axarr[0].plot(fake_id_acc, 'r', label="fake")
axarr[0].legend()
axarr[1].set_title("Pose")
axarr[1].plot(real_pose_acc, label="real")
axarr[1].plot(fake_pose_acc, 'r', label="fake")
axarr[1].legend()
axarr[2].set_title("Real / fake")
axarr[2].plot(real_gan_acc, label="real (recognized as real)")
axarr[2].plot(fake_gan_acc, 'r', label="fake (recognized as fake")
axarr[2].legend()
pyplot.show()
del fdr
del fdf
fdg = bob.io.base.HDF5File(logdir + '/generator_stats.hdf5')
gen_id_acc = fdg.read('g_id_accuracy')
gen_pose_acc = fdg.read('g_pose_accuracy')
gen_gan_acc = fdg.read('g_fake_accuracy')
f, axarr = pyplot.subplots(3, sharex=True)
f.suptitle("Generator stats")
axarr[0].set_title("Identity")
axarr[0].plot(gen_id_acc)
axarr[1].set_title("Pose")
axarr[1].plot(gen_pose_acc)
axarr[2].set_title("Real / fake")
axarr[2].plot(gen_gan_acc)
pyplot.show()
del fdg
......@@ -79,7 +79,8 @@ setup(
'train_wcgan_multipie.py = bob.learn.pytorch.scripts.train_wcgan_multipie:main',
'train_drgan_multipie.py = bob.learn.pytorch.scripts.train_drgan_multipie:main',
'train_drgan_mpie_casia.py = bob.learn.pytorch.scripts.train_drgan_mpie_casia:main',
'read_training_hdf5.py = bob.learn.pytorch.scripts.read_training_hdf5:main',
'show_training_images.py = bob.learn.pytorch.scripts.show_training_images:main',
'show_training_stats.py = bob.learn.pytorch.scripts.show_training_stats:main',
],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment