show_training_stats.py 3.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python
# encoding: utf-8

""" Read data saved during the training of a DR-GAN

Usage:
  %(prog)s [--logdir=<path>] [--verbose ...] 

Options:
  -h, --help                    Show this screen.
  -V, --version                 Show version.
  -d, --logdir=<path>           The dir where the training data reside
  -v, --verbose                 Increase the verbosity (may appear multiple times).


Example:

  To read and display the training data:

    $ %(prog)s --logdir ./drgan/logs

See '%(prog)s --help' for more information.

"""

import os, sys
import pkg_resources

import bob.core
# Module-wide logger; its verbosity is adjusted from --verbose inside main().
logger = bob.core.log.setup("bob.learn.pytorch")

from docopt import docopt

# Installed version of the package; substituted into the usage docstring and
# reported by --version.
version = pkg_resources.require('bob.learn.pytorch')[0].version

# NOTE(review): numpy is not referenced in the visible code of this script.
import numpy
import bob.io.base

from matplotlib import pyplot

def _latest_file(pattern):
  """Return the most recently created file matching a glob pattern.

  :param pattern: glob pattern to match (e.g. ``<logdir>/losses_*``).
  :return: the matching path with the largest ``os.path.getctime``, or
    ``None`` when nothing matches (instead of letting ``max`` raise).
  """
  import glob
  candidates = glob.glob(pattern)
  if not candidates:
    return None
  return max(candidates, key=os.path.getctime)


def _read_datasets(path, keys):
  """Read several datasets from one HDF5 file.

  The file handle is a local of this function, so it is dropped as soon as
  the data has been read rather than staying referenced while the (blocking)
  plot window is open.

  :param path: path to the HDF5 file.
  :param keys: dataset names to read, in order.
  :return: list of the read datasets, in the same order as ``keys``.
  """
  hdf5 = bob.io.base.HDF5File(path)
  return [hdf5.read(key) for key in keys]


def main(user_input=None):
  """Plot the discriminator and generator statistics saved in a DR-GAN log dir.

  :param user_input: list of command-line arguments to parse instead of
    ``sys.argv[1:]``.
  """
  # Parse the command-line arguments
  if user_input is not None:
    arguments = user_input
  else:
    arguments = sys.argv[1:]

  prog = os.path.basename(sys.argv[0])
  completions = dict(prog=prog, version=version,)
  args = docopt(__doc__ % completions, argv=arguments, version='Train DR-GAN (%s)' % version,)

  # verbosity
  bob.core.log.set_verbosity_level(logger, args['--verbose'])

  # --logdir is optional in the usage string and has no docopt default, so it
  # is None when omitted; fall back to the current directory rather than
  # failing on string concatenation with None.
  logdir = args['--logdir']
  if logdir is None:
    logdir = os.curdir

  # === LOSSES ===
  # Only the name of the most recent losses file is reported.
  # TODO: plotting of the losses is disabled; it used to read the 'd_loss'
  # and 'g_loss' datasets of this file.
  loss_filename = _latest_file(os.path.join(logdir, 'losses_*'))
  if loss_filename is None:
    logger.warning("No 'losses_*' file found in '%s'", logdir)
  else:
    print(loss_filename)

  # === DISCRIMINATOR ===
  real_id_acc, real_pose_acc, real_gan_acc = _read_datasets(
      os.path.join(logdir, 'discriminator_real_stats.hdf5'),
      ['r_id_accuracy', 'r_pose_accuracy', 'r_real_accuracy'])
  fake_id_acc, fake_pose_acc, fake_gan_acc = _read_datasets(
      os.path.join(logdir, 'discriminator_fake_stats.hdf5'),
      ['f_id_accuracy', 'f_pose_accuracy', 'f_fake_accuracy'])

  f, axarr = pyplot.subplots(3, sharex=True)
  f.suptitle("Discriminator stats")
  axarr[0].set_title("Identity")
  axarr[0].plot(real_id_acc, label="real")
  axarr[0].plot(fake_id_acc, 'r', label="fake")
  axarr[0].legend()
  axarr[1].set_title("Pose")
  axarr[1].plot(real_pose_acc, label="real")
  axarr[1].plot(fake_pose_acc, 'r', label="fake")
  axarr[1].legend()
  axarr[2].set_title("Real / fake")
  axarr[2].plot(real_gan_acc, label="real (recognized as real)")
  axarr[2].plot(fake_gan_acc, 'r', label="fake (recognized as fake)")
  axarr[2].legend()
  pyplot.show()

  # === GENERATOR ===
  gen_id_acc, gen_pose_acc, gen_gan_acc = _read_datasets(
      os.path.join(logdir, 'generator_stats.hdf5'),
      ['g_id_accuracy', 'g_pose_accuracy', 'g_fake_accuracy'])

  f, axarr = pyplot.subplots(3, sharex=True)
  f.suptitle("Generator stats")
  axarr[0].set_title("Identity")
  axarr[0].plot(gen_id_acc)
  axarr[1].set_title("Pose")
  axarr[1].plot(gen_pose_acc)
  axarr[2].set_title("Real / fake")
  axarr[2].plot(gen_gan_acc)
  pyplot.show()