From 0fa399460a022a575273bd835bba0b374b922476 Mon Sep 17 00:00:00 2001
From: Guillaume HEUSCH <guillaume.heusch@idiap.ch>
Date: Fri, 1 Dec 2017 15:28:27 +0100
Subject: [PATCH] [script] added script to train DR-GAN (original and light
 versions) on both Multi-PIE and CASIA

---
 .../pytorch/scripts/train_drgan_mpie_casia.py | 178 ++++++++++++++++++
 setup.py                                      |   1 +
 2 files changed, 179 insertions(+)
 create mode 100644 bob/learn/pytorch/scripts/train_drgan_mpie_casia.py

diff --git a/bob/learn/pytorch/scripts/train_drgan_mpie_casia.py b/bob/learn/pytorch/scripts/train_drgan_mpie_casia.py
new file mode 100644
index 0000000..6b69c9c
--- /dev/null
+++ b/bob/learn/pytorch/scripts/train_drgan_mpie_casia.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+
+""" Train a DR-GAN 
+
+Usage:
+  %(prog)s [--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>] 
+           [--batch-size=<int>] [--epochs=<int>] [--sample=<int>] [--light]
+           [--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...] [--plot]
+
+Options:
+  -h, --help                    Show this screen.
+  -V, --version                 Show version.
+  -l, --latent-dim=<int>        The dimension of the encoded ID [default: 320]
+  -n, --noise-dim=<int>         The dimension of the noise [default: 50]
+  -c, --conditional-dim=<int>   The dimension of the conditioning variable [default: 13]
+  -b, --batch-size=<int>        The size of your mini-batch [default: 64]
+  -e, --epochs=<int>            The number of training epochs [default: 100]
+  -s, --sample=<int>            Save generated images every <int> batch iterations [default: 100000000000]
+  -L, --light                   Use a lighter architecture (similar to DCGAN)
+  -o, --output-dir=<path>       Directory to save the logs, models and images [default: ./drgan-light-mpie-casia/]
+  -g, --use-gpu                 Use the GPU
+  -S, --seed=<int>              The random seed [default: 3]
+  -v, --verbose                 Increase the verbosity (may appear multiple times).
+  -P, --plot                    Show some images during the training process (mainly for debugging)
+
+Example:
+
+  To run the training process:
+
+    $ %(prog)s --batch-size 64 --epochs 25 --output-dir drgan 
+
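+  To train the lighter (DCGAN-like) architecture on a GPU:
+
+    $ %(prog)s --light --use-gpu --batch-size 64 --epochs 25 --output-dir drgan-light
+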
+See '%(prog)s --help' for more information.
+
+"""
+
+import os, sys
+import pkg_resources
+
+import bob.core
+logger = bob.core.log.setup("bob.learn.pytorch")
+
+from docopt import docopt
+
+version = pkg_resources.require('bob.learn.pytorch')[0].version
+
+import numpy
+import bob.io.base
+
+# torch
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision.transforms as transforms
+import torchvision.utils as vutils
+from torch.autograd import Variable
+
+# data and architecture from the package
+from bob.learn.pytorch.datasets import MultiPIEDataset
+from bob.learn.pytorch.datasets import CasiaDataset
+
+from torch.utils.data import ConcatDataset
+
+from bob.learn.pytorch.datasets import RollChannels
+from bob.learn.pytorch.datasets import ToTensor
+from bob.learn.pytorch.datasets import Normalize
+
+from bob.learn.pytorch.architectures import weights_init
+
+from bob.learn.pytorch.trainers import DRGANTrainer
+
+
+def main(user_input=None):
+  
+  # Parse the command-line arguments
+  if user_input is not None:
+      arguments = user_input
+  else:
+      arguments = sys.argv[1:]
+
+  prog = os.path.basename(sys.argv[0])
+  completions = dict(prog=prog, version=version,)
+  args = docopt(__doc__ % completions, argv=arguments, version='Train DR-GAN (%s)' % version)
+
+  # verbosity
+  verbosity_level = args['--verbose']
+  bob.core.log.set_verbosity_level(logger, verbosity_level)
+
+  # get the arguments
+  noise_dim = int(args['--noise-dim'])
+  latent_dim = int(args['--latent-dim'])
+  conditional_dim = int(args['--conditional-dim'])
+  batch_size = int(args['--batch-size'])
+  epochs = int(args['--epochs'])
+  sample = int(args['--sample'])
+  output_dir = str(args['--output-dir'])
+  seed = int(args['--seed'])
+  use_gpu = bool(args['--use-gpu'])
+  plot = bool(args['--plot'])
+
+  if bool(args['--light']):
+    from bob.learn.pytorch.architectures import DRGAN_encoder as drgan_encoder
+    from bob.learn.pytorch.architectures import DRGAN_decoder as drgan_decoder
+    from bob.learn.pytorch.architectures import DRGAN_discriminator as drgan_discriminator
+    multipie_root_dir = '/idiap/temp/heusch/data/multipie-cropped-64x64'
+    casia_root_dir = '/idiap/temp/heusch/data/casia-webface-cropped-64x64-pose-clusters/'
+  else:
+    from bob.learn.pytorch.architectures import DRGANOriginal_encoder as drgan_encoder
+    from bob.learn.pytorch.architectures import DRGANOriginal_decoder as drgan_decoder
+    from bob.learn.pytorch.architectures import DRGANOriginal_discriminator as drgan_discriminator
+    multipie_root_dir = '/idiap/temp/heusch/data/multipie-cropped-96x96/'
+    casia_root_dir = '/idiap/temp/heusch/data/casia-webface-96x96-cluster-color/'
+
+  # process on the arguments / options
+  torch.manual_seed(seed)
+  if use_gpu:
+    torch.cuda.manual_seed_all(seed)
+  if torch.cuda.is_available() and not use_gpu:
+    logger.warn("You have a CUDA device, so you should probably run with --use-gpu")
+  
+  bob.io.base.create_directories_safe(output_dir)
+
+  # ============
+  # === DATA ===
+  # ============
+ 
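+  # preprocessing applied to every sample: channel reordering, conversion to a
+  # torch tensor and per-channel normalization with mean 0.5 and std 0.5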
+  data_transform = transforms.Compose([RollChannels(), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+  # Multi-PIE
+  face_dataset_1 = MultiPIEDataset(root_dir=multipie_root_dir, 
+                                   transform=data_transform)
+  
+  # get the number of ids
+  number_of_ids = numpy.max(face_dataset_1.id_labels) + 1
+  logger.info("There are {} images from {} different identities in Multi-PIE".format(len(face_dataset_1), number_of_ids))
+  
+  # CASIA Webface
+  face_dataset_2 = CasiaDataset(root_dir=casia_root_dir,
+                                transform=data_transform)
+
+  min_index_casia = numpy.min(face_dataset_2.id_labels)
+  max_index_casia = numpy.max(face_dataset_2.id_labels)
+  logger.info("There are {} images from {} different identities in CASIA Webface".format(len(face_dataset_2), (max_index_casia - min_index_casia)))
+  
+  # Total
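+  # (assuming CASIA identity labels are offset to follow the Multi-PIE ones,
+  # the highest CASIA label + 1 gives the total number of identities)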
+  number_of_ids = max_index_casia + 1 
+  face_dataset = ConcatDataset([face_dataset_1, face_dataset_2])
+  logger.info("There are {} images from {} different identities in TOTAL".format(len(face_dataset), number_of_ids))
+
+  # DataLoader
+  dataloader = torch.utils.data.DataLoader(face_dataset, batch_size=batch_size, shuffle=True)
+
+  # get the image size
+  image_size = face_dataset[0]['image'].numpy().shape
+
+  # ===============
+  # === NETWORK ===
+  # ===============
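+  # the encoder maps a face image to an identity representation, the decoder
+  # generates a face from (identity, pose condition, noise), and the
+  # discriminator classifies identity and pose of real / generated images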
+  encoder = drgan_encoder(image_size, latent_dim)
+  encoder.apply(weights_init)
+  logger.info("Encoder architecture: {}".format(encoder))
+
+  decoder = drgan_decoder(image_size, noise_dim, latent_dim, conditional_dim)
+  decoder.apply(weights_init)
+  logger.info("Generator architecture: {}".format(decoder))
+
+  discriminator = drgan_discriminator(image_size, number_of_ids, conditional_dim)
+  discriminator.apply(weights_init)
+  logger.info("Discriminator architecture: {}".format(discriminator))
+
+
+  # ===============
+  # === TRAINER ===
+  # ===============
+  trainer = DRGANTrainer(encoder, decoder, discriminator, image_size, batch_size=batch_size,
+                         noise_dim=noise_dim, conditional_dim=conditional_dim, latent_dim=latent_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
+  trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, plot=plot)
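+
+
+# the console-script entry point in setup.py calls main(); this guard simply
+# allows running the file directly as well
+if __name__ == '__main__':
+  main()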
diff --git a/setup.py b/setup.py
index 128fb58..f772114 100644
--- a/setup.py
+++ b/setup.py
@@ -78,6 +78,7 @@ setup(
         'train_conditionalgan_casia.py = bob.learn.pytorch.scripts.train_conditionalgan_casia:main', 
         'train_wcgan_multipie.py = bob.learn.pytorch.scripts.train_wcgan_multipie:main', 
         'train_drgan_multipie.py = bob.learn.pytorch.scripts.train_drgan_multipie:main', 
+        'train_drgan_mpie_casia.py = bob.learn.pytorch.scripts.train_drgan_mpie_casia:main', 
         'read_training_hdf5.py = bob.learn.pytorch.scripts.read_training_hdf5:main', 
       ],
 
-- 
GitLab