Commit c75f9631 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

Merge branch 'autoencoder_pretrain' into 'master'

autoencoders pretraining using RGB faces

See merge request !6
parents afaf968f 7e723289
Pipeline #26283 passed with stages in 7 minutes and 46 seconds
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
#==============================================================================
# Import here:
from torch import nn
#==============================================================================
# Define the network:
class ConvAutoencoder(nn.Module):

    def __init__(self):
        super(ConvAutoencoder, self).__init__()

        self.encoder = nn.Sequential(nn.Conv2d(3, 16, 5, padding=2),
                                     nn.ReLU(True),
                                     nn.MaxPool2d(2),
                                     nn.Conv2d(16, 16, 5, padding=2),
                                     nn.ReLU(True),
                                     nn.MaxPool2d(2),
                                     nn.Conv2d(16, 16, 3, padding=2),
                                     nn.ReLU(True),
                                     nn.MaxPool2d(2),
                                     nn.Conv2d(16, 16, 3, padding=2),
                                     nn.ReLU(True),
                                     nn.MaxPool2d(2))

        self.decoder = nn.Sequential(nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1),
                                     nn.ReLU(True),
                                     nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1),
                                     nn.ReLU(True),
                                     nn.ConvTranspose2d(16, 16, 5, stride=2, padding=2),
                                     nn.ReLU(True),
                                     nn.ConvTranspose2d(16, 3, 5, stride=2, padding=2),
                                     nn.ReLU(True),
                                     nn.ConvTranspose2d(3, 3, 2, stride=1, padding=1),
                                     nn.Tanh())

    def forward(self, x):
        """
        The forward method.
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return x
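#==============================================================================
# A minimal smoke test (an illustration added here, not part of the merge):
# with 64x64 RGB inputs the encoder downsamples four times, while the two
# kernel-3, padding=2 convolutions slightly enlarge the maps, giving a
# 16 x 5 x 5 bottleneck; the decoder maps it back to 3 x 64 x 64, with values
# in [-1, 1] due to the final Tanh.

if __name__ == "__main__":

    import torch

    net = ConvAutoencoder()
    batch = torch.randn(8, 3, 64, 64)  # a fake batch of eight RGB face crops

    code = net.encoder(batch)
    reconstruction = net(batch)

    print(code.shape)            # torch.Size([8, 16, 5, 5])
    print(reconstruction.shape)  # torch.Size([8, 3, 64, 64])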
@@ -7,6 +7,7 @@ from .DCGAN import DCGAN_discriminator
from .ConditionalGAN import ConditionalGAN_generator
from .ConditionalGAN import ConditionalGAN_discriminator
from .ConvAutoencoder import ConvAutoencoder
from .utils import weights_init
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
#==============================================================================
# Import here:
from torchvision import transforms
from bob.pad.face.database import CELEBAPadDatabase
from torch import nn
#==============================================================================
# Define parameters here:
"""
Note: do not change names of the below constants.
"""
NUM_EPOCHS = 70 # Maximum number of epochs
BATCH_SIZE = 32 # Size of the batch
LEARNING_RATE = 1e-3 # Learning rate
NUM_WORKERS = 8 # The number of workers for the DataLoader
"""
Transformations to be applied sequentially to the input PIL image.
Note: the variable name ``transform`` must be the same in all configuration files.
"""
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ])
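# Note (an added remark): ToTensor scales pixel values to [0, 1], and the
# above Normalize maps them to [-1, 1], which matches the Tanh output range
# of the ConvAutoencoder decoder, so inputs and reconstructions are directly
# comparable by the MSE loss defined below.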
"""
Set the parameters of the DataFolder dataset class.
Note: do not change the name ``kwargs``.
"""
bob_hldi_instance = CELEBAPadDatabase(original_directory="", original_extension="")
kwargs = {}
kwargs["data_folder"] = "NO NEED TO SET HERE, WILL BE SET IN THE TRAINING SCRIPT"
kwargs["transform"] = transform
kwargs["extension"] = '.hdf5'
kwargs["bob_hldi_instance"] = bob_hldi_instance
kwargs["hldi_type"] = "pad"
kwargs["groups"] = ['train']
kwargs["protocol"] = 'grandtest'
kwargs["purposes"] = ['real']
kwargs["allow_missing_files"] = True
"""
Define the network to be trained as a class, named ``Network``.
Note: Do not change the name of the below class.
"""
from bob.learn.pytorch.architectures import ConvAutoencoder as Network
"""
Define the loss to be used for training.
Note: do not change the name of the below variable.
"""
loss_type = nn.MSELoss()
"""
OPTIONAL: if not defined loss will be computed in the training script.
See training script for details
Define the function to compute the loss. Don't change the signature of this
function.
"""
# we don't define the loss_function for this configuration
#def loss_function(output, img, target):
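# A hypothetical sketch of such a function, kept commented out because this
# configuration relies on the training script's default loss computation:
#
# def loss_function(output, img, target):
#     # reconstruction loss between the network output and the input image;
#     # the class label ``target`` is ignored for plain autoencoder training
#     return loss_type(output, img)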
from .casia_webface import CasiaDataset
from .casia_webface import CasiaWebFaceDataset
from .data_folder import DataFolder
# transforms
from .utils import FaceCropper
@@ -11,7 +12,5 @@ from .utils import Resize
from .utils import map_labels
from .utils import ConcatDataset
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
#==============================================================================
# Import what is needed here:
import torch.utils.data as data
import os
import random
random.seed(a=7)  # fix the random seed for reproducibility
import PIL
import numpy as np
from torchvision import transforms
import torch
import h5py
#==============================================================================
def get_file_names_and_labels(files, data_folder, extension=".hdf5", hldi_type="pad"):
    """
    Get absolute names of the corresponding file objects and their class labels,
    as well as keys defining the name of the frame to load the data from.

    Parameters
    ----------
    files : [File]
        A list of File objects defined in the High Level Database Interface
        of the particular database.

    data_folder : str
        A directory containing the training data.

    extension : str
        Extension of the data files. Default: ".hdf5".

    hldi_type : str
        Type of the high level database interface. Default: "pad".
        Note: this is the only type supported at the moment.

    Returns
    -------
    file_names_labels_keys : [(str, int, str)]
        A list of tuples, where each tuple contains an absolute filename,
        the corresponding class label, and a key defining the name of the
        frame to extract the data from.
    """

    file_names_labels_keys = []

    if hldi_type == "pad":

        for f in files:

            # real samples get label 0, attacks get label 1:
            if f.attack_type is None:
                label = 0
            else:
                label = 1

            file_name = os.path.join(data_folder, f.path + extension)

            if os.path.isfile(file_name):  # if the file is available

                with h5py.File(file_name, "r") as f_h5py:
                    file_keys = list(f_h5py.keys())

                # Elements of the tuples in the below list are as follows:
                # a filename the keys are extracted from,
                # a label corresponding to the file,
                # a key defining a frame from the file.
                file_names_labels_keys = file_names_labels_keys + \
                    [(file_name, label, key) for key in file_keys]

    return file_names_labels_keys
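# A usage sketch (``db`` and the data path are assumptions for illustration):
#
#   files = db.objects(protocol='grandtest', groups=['train'], purposes=['real'])
#   file_names_labels_keys = get_file_names_and_labels(files, "/path/to/data")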
#==============================================================================
class DataFolder(data.Dataset):
    """
    A generic data loader compatible with Bob High Level Database Interfaces
    (HLDI). Only HLDIs of bob.pad.face are currently supported.

    Attributes
    ----------
    data_folder : str
        A directory containing the training data.

    transform : object
        A function/transform that takes in a PIL image and returns a
        transformed version. E.g., ``transforms.RandomCrop``. Default: None.

    extension : str
        Extension of the data files. Default: ".hdf5".
        Note: this is the only extension supported at the moment.

    bob_hldi_instance : object
        An instance of the HLDI interface. Only HLDIs of bob.pad.face
        are currently supported.

    hldi_type : str
        String defining the type of the HLDI. Default: "pad".
        Note: this is the only option currently supported.

    groups : str or [str]
        The groups for which the clients should be returned.
        Usually, groups are one or more elements of ['train', 'dev', 'eval'].
        Default: ['train', 'dev', 'eval'].

    protocol : str
        The protocol for which the clients should be retrieved.
        Default: 'grandtest'.

    purposes : str or [str]
        The purposes for which File objects should be retrieved.
        Usually it is either 'real' or 'attack'.
        Default: ['real', 'attack'].

    allow_missing_files : bool
        Missing files in the ``data_folder`` will not break the execution
        if set to True. Default: True.
    """

    def __init__(self, data_folder,
                 transform=None,
                 extension='.hdf5',
                 bob_hldi_instance=None,
                 hldi_type="pad",
                 groups=['train', 'dev', 'eval'],
                 protocol='grandtest',
                 purposes=['real', 'attack'],
                 allow_missing_files=True,
                 **kwargs):
"""
Attributes
----------
data_folder : str
A directory containing the training data.
transform : object
A function/transform that takes in a PIL image, and returns a
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
extension : str
Extension of the data files. Default: ".hdf5".
Note: this is the only extension supported at the moment.
bob_hldi_instance : object
An instance of the HLDI interface. Only HLDI's of bob.pad.face
are currently supported.
hldi_type : str
String defining the type of the HLDI. Default: "pad".
Note: this is the only option currently supported.
groups : str or [str]
The groups for which the clients should be returned.
Usually, groups are one or more elements of ['train', 'dev', 'eval'].
Default: ['train', 'dev', 'eval'].
protocol : str
The protocol for which the clients should be retrieved.
Default: 'grandtest'.
purposes : str or [str]
The purposes for which File objects should be retrieved.
Usually it is either 'real' or 'attack'.
Default: ['real', 'attack'].
allow_missing_files : bool
The missing files in the ``data_folder`` will not break the
execution if set to True.
Default: True.
"""
self.data_folder = data_folder
self.transform = transform
self.extension = extension
self.bob_hldi_instance = bob_hldi_instance
self.hldi_type = hldi_type
self.groups = groups
self.protocol = protocol
self.purposes = purposes
self.allow_missing_files = allow_missing_files
if bob_hldi_instance is not None:
files = bob_hldi_instance.objects(groups = self.groups,
protocol = self.protocol,
purposes = self.purposes,
**kwargs)
file_names_labels_keys = get_file_names_and_labels(files = files,
data_folder = self.data_folder,
extension = self.extension,
hldi_type = self.hldi_type)
if self.allow_missing_files: # return only existing files
file_names_labels_keys = [f for f in file_names_labels_keys if os.path.isfile(f[0])]
else:
# TODO - add behaviour similar to image folder
file_names_labels_keys = []
self.file_names_labels_keys = file_names_labels_keys
    #==========================================================================
    def __getitem__(self, index):
        """
        Returns an image, possibly transformed, and a target class given index.

        Parameters
        ----------
        index : int
            An index of the sample to return.

        Returns
        -------
        pil_img : Tensor or PIL Image
            If ``self.transform`` is defined the output is a torch.Tensor,
            otherwise the output is an instance of the PIL.Image.Image class.

        target : int
            Index of the class.
        """

        path, target, key = self.file_names_labels_keys[index]

        with h5py.File(path, "r") as f_h5py:

            img_array = np.array(f_h5py.get(key + '/array'))  # the size now is (3 x W x H)

        # torchvision transforms (and the no-transform case) operate on PIL images:
        if isinstance(self.transform, transforms.Compose) or self.transform is None:

            if len(img_array.shape) == 3:  # for color images

                img_array_tr = np.swapaxes(img_array, 1, 2)
                img_array_tr = np.swapaxes(img_array_tr, 0, 2)

                pil_img = PIL.Image.fromarray(img_array_tr)  # convert to PIL from array of size (H x W x 3)

            else:  # for gray-scale images

                pil_img = PIL.Image.fromarray(img_array, 'L')  # convert to PIL from array of size (H x W)

            if self.transform is not None:

                pil_img = self.transform(pil_img)

        else:  # if a custom transformation function is given

            img_array_transformed = self.transform(img_array)

            # convert the array to a Tensor, and also return the target:
            return torch.Tensor(img_array_transformed).unsqueeze(0), target

        return pil_img, target
    #==========================================================================
    def __len__(self):
        """
        Returns
        -------
        len : int
            The length of the file list.
        """
        return len(self.file_names_labels_keys)
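#==============================================================================
# A minimal usage sketch (the paths and database instance are assumptions for
# illustration; the training script builds these from a configuration file):
#
#   from bob.pad.face.database import CELEBAPadDatabase
#
#   dataset = DataFolder(data_folder="/path/to/preprocessed/frames",
#                        transform=transforms.Compose([transforms.ToTensor()]),
#                        bob_hldi_instance=CELEBAPadDatabase(original_directory="",
#                                                            original_extension=""),
#                        groups=['train'], protocol='grandtest', purposes=['real'])
#
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
#   img, target = dataset[0]  # a single (image tensor, label) pair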
@@ -26,7 +26,7 @@ class FaceCropper():
        cropped = self.face_cropper(sample['image'], sample['eyes'])
        sample['image'] = cropped
        return sample
class RollChannels(object):
    """
@@ -41,7 +41,7 @@ class RollChannels(object):
class ToTensor(object):
    def __init__(self):
        self.op = transforms.ToTensor()

    def __call__(self, sample):
        sample['image'] = self.op(sample['image'])
        return sample

@@ -70,14 +70,14 @@ class Resize(object):
def map_labels(raw_labels, start_index=0):
    """
    Map the ID label to [0 - # of IDs]
    """
    possible_labels = list(set(raw_labels))
    labels = numpy.array(raw_labels)
    for i in range(len(possible_labels)):
        l = possible_labels[i]
-        labels[numpy.where(labels==l)[0]] = i + start_index
+        labels[numpy.where(labels==l)[0][0]] = i + start_index

    # -----
    # map back to native int, resolve the problem with dataset concatenation
@@ -86,7 +86,7 @@ def map_labels(raw_labels, start_index=0):
    labels_int = []
    for i in range(len(labels)):
        labels_int.append(labels[i].item())

    return labels_int
@@ -105,12 +105,12 @@ class ConcatDataset(Dataset):
        The list of datasets (as torch.utils.data.Dataset)
    """
    def __init__(self, datasets):
        self.transform = datasets[0].transform
        self.data_files = sum((d.data_files for d in datasets), [])
        self.pose_labels = sum((d.pose_labels for d in datasets), [])
        self.id_labels = sum((d.id_labels for d in datasets), [])

    def __len__(self):
        """
        return the length of the dataset (i.e. nb of examples)
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
The following steps are performed in this script:
1. The command line arguments are first parsed.
2. Folder to save the results to is created.
3. Configuration file specifying the Network and learning parameters is
loaded.
4. A generic data loader compatible with Bob High Level Database
Interfaces, namely DataFolder, is initialized.
5. The Network is initialized, can also be initialized with pre-trained
model.
6. The training is performed. Verbosity flag can be used to see and save
training related outputs. See ``process_verbosity`` function for
more details.
7. The model is saved after each 1 epochs.
@author: Olegs Nikisins
"""
#==============================================================================
# Import here:
import argparse
import importlib
import os
from bob.learn.pytorch.datasets import DataFolder
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.utils import save_image
import logging
logger = logging.getLogger("bob.learn.pytorch")
import numpy as np
import time
#==============================================================================
def parse_arguments(cmd_params=None):
    """
    Parse command line arguments.

    **Parameters:**

    ``cmd_params``: []
        An optional list of command line arguments. Default: None.

    **Returns:**

    ``data_folder``: py:class:`string`
        A directory containing the training data.

    ``save_folder``: py:class:`string`
        A directory to save the results of training to.

    ``relative_mod_name``: py:class:`string`
        Relative name of the module to import configurations from.

    ``config_group``: py:class:`string`
        Group/package name containing the configuration file.

    ``pretrained_model_path``: py:class:`string`
        Absolute name of the file containing the pre-trained Network
        model, to be used for Network initialization before training.

    ``cross_validate``: bool
        Cross-validate the current model on the dev set of the database used
        for training. Cross-validation is done after each training epoch, using
        the entire development set of the database.

    ``verbosity``: py:class:`int`
        Verbosity level.
    """

    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument("data_folder", type=str,
                        help="A directory containing the training data.")

    parser.add_argument("save_folder", type=str,
                        help="A directory to save the results of training to.")

    parser.add_argument("-c", "--config-file", type=str,
                        help="Relative name of the config file defining "
                             "the network, training data, and training parameters.",
                        default="autoencoder/net1_celeba.py")

    parser.add_argument("-cg", "--config-group", type=str,
                        help="Name of the group where the config file is stored.",
                        default="bob.learn.pytorch.config")