diff --git a/bob/learn/pytorch/datasets/__init__.py b/bob/learn/pytorch/datasets/__init__.py index b03b5fb0212744b085f0ef6254fce7c76a18479b..5772fc337c6d9e06bd7c12f78afcf4fbf13cf8d8 100644 --- a/bob/learn/pytorch/datasets/__init__.py +++ b/bob/learn/pytorch/datasets/__init__.py @@ -1,4 +1,5 @@ from .multipie import MultiPIEDataset +from .casia_webface import CasiaDataset # transforms from .utils import RollChannels diff --git a/bob/learn/pytorch/datasets/casia_webface.py b/bob/learn/pytorch/datasets/casia_webface.py new file mode 100644 index 0000000000000000000000000000000000000000..955e0e4e3424c7c009a8c8804a92d56930978b27 --- /dev/null +++ b/bob/learn/pytorch/datasets/casia_webface.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# encoding: utf-8 + +import os + +import torch +import numpy + +from torch.utils.data import Dataset, DataLoader + +import bob.db.casia_webface +import bob.io.base +import bob.io.image + +from .utils import map_labels + +class CasiaDataset(Dataset): + """Casia WebFace dataset. + + Class representing the CASIA WebFace dataset + + **Parameters** + + root-dir: path + The path to the data + + frontal_only: boolean + If you want to only use frontal faces + + transform: torchvision.transforms + The transform(s) to apply to the face images + """ + + # TODO: Start from original data and annotations - Guillaume HEUSCH, 06-11-2017 + + def __init__(self, root_dir, frontal_only=False, transform=None): + self.root_dir = root_dir + self.transform = transform + + dir_to_pose_label = {'l90': '0', + 'l75': '1', + 'l60': '2', + 'l45': '3', + 'l30': '4', + 'l15': '5', + '0' : '6', + 'r15': '7', + 'r30': '8', + 'r45': '9', + 'r60': '10', + 'r75': '11', + 'r90': '12', + } + + # get all the needed file, the pose labels, and the id labels + self.data_files = [] + self.pose_labels = [] + id_labels = [] + + for root, dirs, files in os.walk(self.root_dir): + for name in files: + filename = os.path.split(os.path.join(root, name))[-1] + path = root.split(os.sep) + subject = int(path[-1]) + cluster = path[-2] + self.data_files.append(os.path.join(root, name)) + self.pose_labels.append(dir_to_pose_label[cluster]) + id_labels.append(subject) + + self.id_labels = map_labels(id_labels) + + + def __len__(self): + """ + return the length of the dataset (i.e. nb of examples) + """ + return len(self.data_files) + + + def __getitem__(self, idx): + """ + return a sample from the dataset + """ + image = bob.io.base.load(self.data_files[idx]) + identity = self.id_labels[idx] + pose = self.pose_labels[idx] + sample = {'image': image, 'id': identity, 'pose': pose} + + if self.transform: + sample = self.transform(sample) + + return sample diff --git a/bob/learn/pytorch/scripts/train_conditionalgan_casia.py b/bob/learn/pytorch/scripts/train_conditionalgan_casia.py new file mode 100644 index 0000000000000000000000000000000000000000..84b28ee1559b79cf2fe758976d2e6d9424e3fb3a --- /dev/null +++ b/bob/learn/pytorch/scripts/train_conditionalgan_casia.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# encoding: utf-8 + + +""" Train a Conditional GAN + +Usage: + %(prog)s [--noise-dim=<int>] [--conditional-dim=<int>] + [--batch-size=<int>] [--epochs=<int>] [--sample=<int>] + [--output-dir=<path>] [--use-gpu] [--seed=<int>] [--verbose ...] + +Options: + -h, --help Show this screen. + -V, --version Show version. + -n, --noise-dim=<int> The dimension of the noise [default: 100] + -c, --conditional-dim=<int> The dimension of the conditional variable [default: 13] + -b, --batch-size=<int> The size of your mini-batch [default: 64] + -e, --epochs=<int> The number of training epochs [default: 100] + -s, --sample=<int> Save generated images at every 'sample' batch iteration [default: 100000000000] + -o, --output-dir=<path> Dir to save the logs, models and images [default: ./cgan-casia] + -g, --use-gpu Use the GPU + -S, --seed=<int> The random seed [default: 3] + -v, --verbose Increase the verbosity (may appear multiple times). + +Example: + + To run the training process + + $ %(prog)s --batch-size 64 --epochs 25 --output-dir drgan + +See '%(prog)s --help' for more information. + +""" + +import os, sys +import pkg_resources + +import bob.core +logger = bob.core.log.setup("bob.learn.pytorch") + +from docopt import docopt + +version = pkg_resources.require('bob.learn.pytorch')[0].version + +import numpy +import bob.io.base + +# torch +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision.transforms as transforms +import torchvision.utils as vutils +from torch.autograd import Variable + +# data and architecture from the package +from bob.learn.pytorch.datasets import CasiaDataset +from bob.learn.pytorch.datasets import RollChannels +from bob.learn.pytorch.datasets import ToTensor +from bob.learn.pytorch.datasets import Normalize + +from bob.learn.pytorch.architectures import ConditionalGAN_generator +from bob.learn.pytorch.architectures import ConditionalGAN_discriminator +from bob.learn.pytorch.architectures import weights_init + +from bob.learn.pytorch.trainers import ConditionalGANTrainer + + +def main(user_input=None): + + # Parse the command-line arguments + if user_input is not None: + arguments = user_input + else: + arguments = sys.argv[1:] + + prog = os.path.basename(sys.argv[0]) + completions = dict(prog=prog, version=version,) + args = docopt(__doc__ % completions,argv=arguments,version='Train DR-GAN (%s)' % version,) + + # verbosity + verbosity_level = args['--verbose'] + bob.core.log.set_verbosity_level(logger, verbosity_level) + + # get the arguments + noise_dim = int(args['--noise-dim']) + conditional_dim = int(args['--conditional-dim']) + batch_size = int(args['--batch-size']) + epochs = int(args['--epochs']) + sample = int(args['--sample']) + output_dir = str(args['--output-dir']) + seed = int(args['--seed']) + use_gpu = bool(args['--use-gpu']) + + images_dir = os.path.join(output_dir, 'samples') + log_dir = os.path.join(output_dir, 'logs') + model_dir = os.path.join(output_dir, 'models') + + # process on the arguments / options + torch.manual_seed(seed) + if use_gpu: + torch.cuda.manual_seed_all(seed) + if torch.cuda.is_available() and not use_gpu: + logger.warn("You have a CUDA device, so you should probably run with --use-gpu") + bob.io.base.create_directories_safe(images_dir) + bob.io.base.create_directories_safe(log_dir) + bob.io.base.create_directories_safe(images_dir) + + # ============ + # === DATA === + # ============ + + # WARNING with the transforms ... act on labels too, at some point, I may have to write my own + # Also, in 'ToTensor', there is a reshape performed from: HxWxC to CxHxW + face_dataset = CasiaDataset(root_dir='/idiap/temp/heusch/data/casia-webface-cropped-64x64-pose-clusters', + frontal_only=False, + transform=transforms.Compose([ + RollChannels(), # bob to skimage: + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + ) + + dataloader = torch.utils.data.DataLoader(face_dataset, batch_size=batch_size, shuffle=True) + logger.info("There are {} training images".format(len(face_dataset))) + + # =============== + # === NETWORK === + # =============== + ngpu = 1 # usually we don't have more than one GPU + + generator = ConditionalGAN_generator(noise_dim, conditional_dim) + generator.apply(weights_init) + logger.info("Generator architecture: {}".format(generator)) + + discriminator = ConditionalGAN_discriminator(conditional_dim) + discriminator.apply(weights_init) + logger.info("Discriminator architecture: {}".format(discriminator)) + + + # =============== + # === TRAINER === + # =============== + trainer = ConditionalGANTrainer(generator, discriminator, [3, 64, 64], batch_size=batch_size, noise_dim=noise_dim, conditional_dim=conditional_dim, use_gpu=use_gpu, verbosity_level=verbosity_level) + trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir) diff --git a/setup.py b/setup.py index 616399401205ac70585c9918183f92e6b4589d6f..1430ca3bc9561143e2bcfe6753390afa585b5fce 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,7 @@ setup( 'console_scripts': [ 'train_dcgan_multipie.py = bob.learn.pytorch.scripts.train_dcgan_multipie:main', 'train_conditionalgan_multipie.py = bob.learn.pytorch.scripts.train_conditionalgan_multipie:main', + 'train_conditionalgan_casia.py = bob.learn.pytorch.scripts.train_conditionalgan_casia:main', ],