Commit 334b9616 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[data, script] added the CASIA WebFace Dataset interface, and the script to...

[data, script] added the CASIA WebFace Dataset interface, and the script to train a conditional GAN on this set
parent fe20611e
from .multipie import MultiPIEDataset
from .casia_webface import CasiaDataset
# transforms
from .utils import RollChannels
#!/usr/bin/env python
# encoding: utf-8
import os
import torch
import numpy
from import Dataset, DataLoader
import bob.db.casia_webface
from .utils import map_labels
class CasiaDataset(Dataset):
"""Casia WebFace dataset.
Class representing the CASIA WebFace dataset
root-dir: path
The path to the data
frontal_only: boolean
If you want to only use frontal faces
transform: torchvision.transforms
The transform(s) to apply to the face images
# TODO: Start from original data and annotations - Guillaume HEUSCH, 06-11-2017
def __init__(self, root_dir, frontal_only=False, transform=None):
self.root_dir = root_dir
self.transform = transform
dir_to_pose_label = {'l90': '0',
'l75': '1',
'l60': '2',
'l45': '3',
'l30': '4',
'l15': '5',
'0' : '6',
'r15': '7',
'r30': '8',
'r45': '9',
'r60': '10',
'r75': '11',
'r90': '12',
# get all the needed file, the pose labels, and the id labels
self.data_files = []
self.pose_labels = []
id_labels = []
for root, dirs, files in os.walk(self.root_dir):
for name in files:
filename = os.path.split(os.path.join(root, name))[-1]
path = root.split(os.sep)
subject = int(path[-1])
cluster = path[-2]
self.data_files.append(os.path.join(root, name))
self.id_labels = map_labels(id_labels)
def __len__(self):
return the length of the dataset (i.e. nb of examples)
return len(self.data_files)
def __getitem__(self, idx):
return a sample from the dataset
image =[idx])
identity = self.id_labels[idx]
pose = self.pose_labels[idx]
sample = {'image': image, 'id': identity, 'pose': pose}
if self.transform:
sample = self.transform(sample)
return sample
#!/usr/bin/env python
# encoding: utf-8
""" Train a Conditional GAN
%(prog)s [--noise-dim=<int>] [--conditional-dim=<int>]
[--batch-size=<int>] [--epochs=<int>] [--sample=<int>]
[--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...]
-h, --help Show this screen.
-V, --version Show version.
-n, --noise-dim=<int> The dimension of the noise [default: 100]
-c, --conditional-dim=<int> The dimension of the conditional variable [default: 13]
-b, --batch-size=<int> The size of your mini-batch [default: 64]
-e, --epochs=<int> The number of training epochs [default: 100]
-s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 100000000000]
-o, --output-dir=<path> Dir to save the logs, models and images [default: ./cgan-casia]
-g, --use-gpu Use the GPU
-S, --seed=<int> The random seed [default: 3]
-v, --verbose Increase the verbosity (may appear multiple times).
To run the training process
$ %(prog)s --batch-size 64 --epochs 25 --output-dir drgan
See '%(prog)s --help' for more information.
import os, sys
import pkg_resources
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
from docopt import docopt
version = pkg_resources.require('bob.learn.pytorch')[0].version
import numpy
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
# data and architecture from the package
from bob.learn.pytorch.datasets import CasiaDataset
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.architectures import ConditionalGAN_generator
from bob.learn.pytorch.architectures import ConditionalGAN_discriminator
from bob.learn.pytorch.architectures import weights_init
from bob.learn.pytorch.trainers import ConditionalGANTrainer
def main(user_input=None):
# Parse the command-line arguments
if user_input is not None:
arguments = user_input
arguments = sys.argv[1:]
prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,)
# verbosity
verbosity_level = args['--verbose']
bob.core.log.set_verbosity_level(logger, verbosity_level)
# get the arguments
noise_dim = int(args['--noise-dim'])
conditional_dim = int(args['--conditional-dim'])
batch_size = int(args['--batch-size'])
epochs = int(args['--epochs'])
sample = int(args['--sample'])
output_dir = str(args['--output-dir'])
seed = int(args['--seed'])
use_gpu = bool(args['--use-gpu'])
images_dir = os.path.join(output_dir, 'samples')
log_dir = os.path.join(output_dir, 'logs')
model_dir = os.path.join(output_dir, 'models')
# process on the arguments / options
if use_gpu:
if torch.cuda.is_available() and not use_gpu:
logger.warn("You have a CUDA device, so you should probably run with --use-gpu")
# ============
# === DATA ===
# ============
# WARNING with the transforms ... act on labels too, at some point, I may have to write my own
# Also, in 'ToTensor', there is a reshape performed from: HxWxC to CxHxW
face_dataset = CasiaDataset(root_dir='/idiap/temp/heusch/data/casia-webface-cropped-64x64-pose-clusters',
RollChannels(), # bob to skimage:
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
dataloader =, batch_size=batch_size, shuffle=True)"There are {} training images".format(len(face_dataset)))
# ===============
# === NETWORK ===
# ===============
ngpu = 1 # usually we don't have more than one GPU
generator = ConditionalGAN_generator(noise_dim, conditional_dim)
generator.apply(weights_init)"Generator architecture: {}".format(generator))
discriminator = ConditionalGAN_discriminator(conditional_dim)
discriminator.apply(weights_init)"Discriminator architecture: {}".format(discriminator))
# ===============
# === TRAINER ===
# ===============
trainer = ConditionalGANTrainer(generator, discriminator, [3, 64, 64], batch_size=batch_size, noise_dim=noise_dim, conditional_dim=conditional_dim, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir)
......@@ -76,6 +76,7 @@ setup(
'console_scripts': [
' = bob.learn.pytorch.scripts.train_dcgan_multipie:main',
' = bob.learn.pytorch.scripts.train_conditionalgan_multipie:main',
' = bob.learn.pytorch.scripts.train_conditionalgan_casia:main',
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment