diff --git a/bob/learn/pytorch/scripts/show_training_stats.py b/bob/learn/pytorch/scripts/show_training_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..884c8b594efff577717a78d306437d2350b34569 --- /dev/null +++ b/bob/learn/pytorch/scripts/show_training_stats.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +# encoding: utf-8 + +""" Read data saved during the training of a DR-GAN + +Usage: + %(prog)s [--logdir=<path>] [--verbose ...] + +Options: + -h, --help Show this screen. + -V, --version Show version. + -d, --logdir=<path> The dir where the training data reside + -v, --verbose Increase the verbosity (may appear multiple times). + + +Example: + + To read and display the training data: + + $ %(prog)s --logdir ./drgan/logs + +See '%(prog)s --help' for more information. + +""" + +import os, sys +import pkg_resources + +import bob.core +logger = bob.core.log.setup("bob.learn.pytorch") + +from docopt import docopt + +version = pkg_resources.require('bob.learn.pytorch')[0].version + +import numpy +import bob.io.base + +from matplotlib import pyplot + +def main(user_input=None): + + # Parse the command-line arguments + if user_input is not None: + arguments = user_input + else: + arguments = sys.argv[1:] + + prog = os.path.basename(sys.argv[0]) + completions = dict(prog=prog, version=version,) + args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,) + + # verbosity + verbosity_level = args['--verbose'] + bob.core.log.set_verbosity_level(logger, verbosity_level) + + # get the arguments + logdir = args['--logdir'] + + + # === LOSSES === + # get the last losses file + import glob + losses_files = glob.glob(logdir + '/losses_*') # * means all if need specific format then *.csv + loss_filename = max(losses_files, key=os.path.getctime) + print loss_filename + + #fl = bob.io.base.HDF5File(loss_filename) + #d_loss = fl.read('d_loss') + #g_loss = fl.read('g_loss') + + #pyplot.title("Losses") + #pyplot.xlabel("# of iterations") + #pyplot.plot(d_loss, 'b', label="discriminator") + #pyplot.plot(g_loss, 'r', label="generator") + #pyplot.legend() + #pyplot.show() + + #del fl + + # === DISCRIMINATOR === + fdr = bob.io.base.HDF5File(logdir + '/discriminator_real_stats.hdf5') + real_id_acc = fdr.read('r_id_accuracy') + real_pose_acc = fdr.read('r_pose_accuracy') + real_gan_acc = fdr.read('r_real_accuracy') + + fdf = bob.io.base.HDF5File(logdir + '/discriminator_fake_stats.hdf5') + fake_id_acc = fdf.read('f_id_accuracy') + fake_pose_acc = fdf.read('f_pose_accuracy') + fake_gan_acc = fdf.read('f_fake_accuracy') + + f, axarr = pyplot.subplots(3, sharex=True) + f.suptitle("Discriminator stats") + axarr[0].set_title("Identity") + axarr[0].plot(real_id_acc, label="real") + axarr[0].plot(fake_id_acc, 'r', label="fake") + axarr[0].legend() + axarr[1].set_title("Pose") + axarr[1].plot(real_pose_acc, label="real") + axarr[1].plot(fake_pose_acc, 'r', label="fake") + axarr[1].legend() + axarr[2].set_title("Real / fake") + axarr[2].plot(real_gan_acc, label="real (recognized as real)") + axarr[2].plot(fake_gan_acc, 'r', label="fake (recognized as fake") + axarr[2].legend() + pyplot.show() + del fdr + del fdf + + fdg = bob.io.base.HDF5File(logdir + '/generator_stats.hdf5') + gen_id_acc = fdg.read('g_id_accuracy') + gen_pose_acc = fdg.read('g_pose_accuracy') + gen_gan_acc = fdg.read('g_fake_accuracy') + + f, axarr = pyplot.subplots(3, sharex=True) + f.suptitle("Generator stats") + axarr[0].set_title("Identity") + axarr[0].plot(gen_id_acc) + axarr[1].set_title("Pose") + axarr[1].plot(gen_pose_acc) + axarr[2].set_title("Real / fake") + axarr[2].plot(gen_gan_acc) + pyplot.show() + del fdg diff --git a/setup.py b/setup.py index f77211452bfcb159453cb7021778ea6acfbe583f..0086f66ae03614ccf50a1710dbb772ec4383476e 100644 --- a/setup.py +++ b/setup.py @@ -79,7 +79,8 @@ setup( 'train_wcgan_multipie.py = bob.learn.pytorch.scripts.train_wcgan_multipie:main', 'train_drgan_multipie.py = bob.learn.pytorch.scripts.train_drgan_multipie:main', 'train_drgan_mpie_casia.py = bob.learn.pytorch.scripts.train_drgan_mpie_casia:main', - 'read_training_hdf5.py = bob.learn.pytorch.scripts.read_training_hdf5:main', + 'show_training_images.py = bob.learn.pytorch.scripts.show_training_images:main', + 'show_training_stats.py = bob.learn.pytorch.scripts.show_training_stats:main', ],