Skip to content
Snippets Groups Projects
Commit 21a3784f authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[script] added very first script

parent 362f15e7
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python
# encoding: utf-8
""" Train a DR-GAN
Usage:
%(prog)s [--latent-dim=<int>]
[--batch-size=<int>] [--epochs=<int>] [--sample=<int>]
[--output-dir=<path>] [--verbose ...]
Options:
-h, --help Show this screen.
-V, --version Show version.
-l, --latent-dim=<int> the dimension of the encoded ID [default: 320]
-b, --batch-size=<int> The size of your mini-batch [default: 128]
-e, --epochs=<int> The number of training epochs [default: 100]
-s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 100000000000]
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./drgan-light-mpie-casia/]
-v, --verbose Increase the verbosity (may appear multiple times).
Example:
To run the training process
$ %(prog)s --batch-size 64 --epochs 25 --output-dir drgan
See '%(prog)s --help' for more information.
"""
import os, sys
import pkg_resources
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
from docopt import docopt
version = pkg_resources.require('bob.learn.pytorch')[0].version
import numpy
from bob.learn.pytorch.datasets.multipie import MultiPIEDataset
#import bob.learn.pytorch.datasets.multipie
def main(user_input=None):
# Parse the command-line arguments
if user_input is not None:
arguments = user_input
else:
arguments = sys.argv[1:]
prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,)
# verbosity
verbosity_level = args['--verbose']
bob.core.log.set_verbosity_level(logger, verbosity_level)
# get the parameters
batch_size = int(args['--batch-size'])
epochs = int(args['--epochs'])
sample = int(args['--sample'])
output_dir = str(args['--output-dir'])
images_dir = os.path.join(output_dir, 'samples')
log_dir = os.path.join(output_dir, 'logs')
model_dir = os.path.join(output_dir, 'models')
#face_dataset = MultiPIEDataset(root_dir='/idiap/resource/database/Multi-Pie/data/')
face_dataset = MultiPIEDataset(root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64')
#print len(face_dataset)
from matplotlib import pyplot
for i in range(len(face_dataset)):
sample = face_dataset[i]
pyplot.title('Sample {}: ID -> {}, pose ->{}'.format(i, sample['id'], sample['pose']))
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(sample['image'], 2),2))
pyplot.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment