Commit 0fa39946 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[script] added script to train DR-GAN (original and light version) on both Multi-PIE and CASIA

parent b6c9164b
#!/usr/bin/env python
# encoding: utf-8
""" Train a DR-GAN
%(prog)s [--latent-dim=<int>] [--noise-dim=<int>] [--conditional-dim=<int>]
[--batch-size=<int>] [--epochs=<int>] [--sample=<int>] [--light]
[--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...] [--plot]
-h, --help Show this screen.
-V, --version Show version.
-l, --latent-dim=<int> the dimension of the encoded ID [default: 320]
-n, --noise-dim=<int> the dimension of the noise [default: 50]
-c, --conditional-dim=<int> the dimension of the conditioning variable [default: 13]
-b, --batch-size=<int> The size of your mini-batch [default: 64]
-e, --epochs=<int> The number of training epochs [default: 100]
-s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 100000000000]
-L, --light Use a lighter architecture (similar as DCGAN)
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./drgan-light-mpie-casia/]
-g, --use-gpu Use the GPU
-S, --seed=<int> The random seed [default: 3]
-v, --verbose Increase the verbosity (may appear multiple times).
-P, --plot Show some image during training process (mainly for debug)
To run the training process
$ %(prog)s --batch-size 64 --epochs 25 --output-dir drgan
See '%(prog)s --help' for more information.
import os, sys
import pkg_resources
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
from docopt import docopt
version = pkg_resources.require('bob.learn.pytorch')[0].version
import numpy
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets import MultiPIEDataset
from bob.learn.pytorch.datasets import CasiaDataset
from import ConcatDataset
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.architectures import weights_init
from bob.learn.pytorch.trainers import DRGANTrainer
def main(user_input=None):
# Parse the command-line arguments
if user_input is not None:
arguments = user_input
arguments = sys.argv[1:]
prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,)
# verbosity
verbosity_level = args['--verbose']
bob.core.log.set_verbosity_level(logger, verbosity_level)
# get the arguments
noise_dim = int(args['--noise-dim'])
latent_dim = int(args['--latent-dim'])
conditional_dim = int(args['--conditional-dim'])
batch_size = int(args['--batch-size'])
epochs = int(args['--epochs'])
sample = int(args['--sample'])
output_dir = str(args['--output-dir'])
seed = int(args['--seed'])
use_gpu = bool(args['--use-gpu'])
plot = bool(args['--plot'])
if bool(args['--light']):
from bob.learn.pytorch.architectures import DRGAN_encoder as drgan_encoder
from bob.learn.pytorch.architectures import DRGAN_decoder as drgan_decoder
from bob.learn.pytorch.architectures import DRGAN_discriminator as drgan_discriminator
multipie_root_dir = '/idiap/temp/heusch/data/multipie-cropped-64x64'
casia_root_dir = '/idiap/temp/heusch/data/casia-webface-cropped-64x64-pose-clusters/'
from bob.learn.pytorch.architectures import DRGANOriginal_encoder as drgan_encoder
from bob.learn.pytorch.architectures import DRGANOriginal_decoder as drgan_decoder
from bob.learn.pytorch.architectures import DRGANOriginal_discriminator as drgan_discriminator
multipie_root_dir = '/idiap/temp/heusch/data/multipie-cropped-96x96/'
casia_root_dir = '/idiap/temp/heusch/data/casia-webface-96x96-cluster-color/'
# process on the arguments / options
if use_gpu:
if torch.cuda.is_available() and not use_gpu:
logger.warn("You have a CUDA device, so you should probably run with --use-gpu")
# ============
# === DATA ===
# ============
data_transform = transforms.Compose([RollChannels(), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Multi-PIE
face_dataset_1 = MultiPIEDataset(root_dir=multipie_root_dir,
# get the number of ids
number_of_ids = numpy.max(face_dataset_1.id_labels) + 1"There are {} images from {} different identities in Multi-PIE".format(len(face_dataset_1), number_of_ids))
# CASIA Webface
face_dataset_2 = CasiaDataset(root_dir=casia_root_dir,
min_index_casia = numpy.min(face_dataset_2.id_labels)
max_index_casia = numpy.max(face_dataset_2.id_labels)"There are {} images from {} different identities in CASIA Webface".format(len(face_dataset_2), (max_index_casia - min_index_casia)))
# Total
number_of_ids = max_index_casia + 1
face_dataset = ConcatDataset([face_dataset_1, face_dataset_2])"There are {} images from {} different identities in TOTAL".format(len(face_dataset), number_of_ids))
# DataLoader
dataloader =, batch_size=batch_size, shuffle=True)
# get the image size
image_size = face_dataset[0]['image'].numpy().shape
# ===============
# === NETWORK ===
# ===============
encoder = drgan_encoder(image_size, latent_dim)
encoder.apply(weights_init)"Encoder architecture: {}".format(encoder))
decoder = drgan_decoder(image_size, noise_dim, latent_dim, conditional_dim)
decoder.apply(weights_init)"Generator architecture: {}".format(decoder))
discriminator = drgan_discriminator(image_size, number_of_ids, conditional_dim)
discriminator.apply(weights_init)"Discriminator architecture: {}".format(discriminator))
# ===============
# === TRAINER ===
# ===============
trainer = DRGANTrainer(encoder, decoder, discriminator, image_size, batch_size=batch_size,
noise_dim=noise_dim, conditional_dim=conditional_dim, latent_dim=latent_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, plot=plot)
......@@ -78,6 +78,7 @@ setup(
' = bob.learn.pytorch.scripts.train_conditionalgan_casia:main',
' = bob.learn.pytorch.scripts.train_wcgan_multipie:main',
' = bob.learn.pytorch.scripts.train_drgan_multipie:main',
' = bob.learn.pytorch.scripts.train_drgan_mpie_casia:main',
' = bob.learn.pytorch.scripts.read_training_hdf5:main',
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment