Commit 06d5b2df authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Added dataset class, Conv-AE model, config to train on CelebA, and train script

parent f979eb7d
Pipeline #26220 failed with stage
in 7 minutes and 15 seconds
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
@author: Olegs Nikisins
# Import here:
from torch import nn
# Define the network:
class ConvAutoencoder(nn.Module):
def __init__(self):
super(ConvAutoencoder, self).__init__()
self.encoder = nn.Sequential(nn.Conv2d(3, 16, 5, padding=2),
nn.Conv2d(16, 16, 5, padding=2),
nn.Conv2d(16, 16, 3, padding=2),
nn.Conv2d(16, 16, 3, padding=2),
self.decoder = nn.Sequential(nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1),
nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1),
nn.ConvTranspose2d(16, 16, 5, stride=2, padding=2),
nn.ConvTranspose2d(16, 3, 5, stride=2, padding=2),
nn.ConvTranspose2d(3, 3, 2, stride=1, padding=1),
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
......@@ -7,6 +7,7 @@ from .DCGAN import DCGAN_discriminator
from .ConditionalGAN import ConditionalGAN_generator
from .ConditionalGAN import ConditionalGAN_discriminator
from .ConvAutoencoder import ConvAutoencoder
from .utils import weights_init
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
@author: Olegs Nikisins
# Import here:
from torchvision import transforms
from bob.pad.face.database import CELEBAPadDatabase
from torch import nn
# Define parameters here:
Note: do not change names of the below constants.
NUM_EPOCHS = 70 # Maximum number of epochs
BATCH_SIZE = 32 # Size of the batch
LEARNING_RATE = 1e-3 # Learning rate
NUM_WORKERS = 8 # The number of workers for the DataLoader
Transformations to be applied sequentially to the input PIL image.
Note: the variable name ``transform`` must be the same in all configuration files.
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
Set the parameters of the DataFolder dataset class.
Note: do not change the name ``kwargs``.
bob_hldi_instance = CELEBAPadDatabase(original_directory = "", original_extension = "")
kwargs = {}
kwargs["transform"] = transform
kwargs["extension"] = '.hdf5'
kwargs["bob_hldi_instance"] = bob_hldi_instance
kwargs["hldi_type"] = "pad"
kwargs["groups"] = ['train']
kwargs["protocol"] = 'grandtest'
kwargs["purposes"] = ['real']
kwargs["allow_missing_files"] = True
Define the network to be trained as a class, named ``Network``.
Note: Do not change the name of the below class.
from bob.learn.pytorch.architectures import ConvAutoencoder as Network
Define the loss to be used for training.
Note: do not change the name of the below variable.
loss_type = nn.MSELoss()
OPTIONAL: if not defined loss will be computed in the training script.
See training script for details
Define the function to compute the loss. Don't change the signature of this
# we don't define the loss_function for this configuration
#def loss_function(output, img, target):
from .casia_webface import CasiaDataset
from .casia_webface import CasiaWebFaceDataset
from .data_folder import DataFolder
# transforms
from .utils import FaceCropper
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
@author: Olegs Nikisins
# Import what is needed here:
import as data
import os
import random
random.seed( a = 7 )
import PIL
import numpy as np
from torchvision import transforms
import torch
import h5py
def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type = "pad"):
Get absolute names of the corresponding file objects and their class labels,
as well as keys defining name of the frame to load the data from.
``files`` : [File]
A list of files objects defined in the High Level Database Interface
of the particular datbase.
``data_folder`` : str
A directory containing the training data.
``extension`` : str
Extension of the data files. Default: ".hdf5" .
``hldi_type`` : str
Type of the high level database interface. Default: "pad".
Note: this is the only type supported at the moment.
``file_names_labels_keys`` : [(str, int, str)]
A list of tuples, where each tuple contain an absolute filename,
a corresponding label of the class, and a key defining the name of the
frame to extract the data from.
file_names_labels_keys = []
if hldi_type == "pad":
for f in files:
if f.attack_type is None:
label = 0
label = 1
file_name = os.path.join(data_folder, f.path + extension)
if os.path.isfile(file_name): # if file is available:
with h5py.File(file_name, "r") as f_h5py:
file_keys = list(f_h5py.keys())
# elements of tuples in the below list are as follows:
# a filename a key is extracted from,
# a label corresponding to the file,
# a key defining a frame from the file.
file_names_labels_keys = file_names_labels_keys + [(file_name, label, key) for file_name, label, key in
zip([file_name]*len(file_keys), [label]*len(file_keys), file_keys)]
return file_names_labels_keys
class DataFolder(data.Dataset):
A generic data loader compatible with Bob High Level Database Interfaces
(HLDI). Only HLDI's of bob.pad.face are currently supported.
def __init__(self, data_folder,
transform = None,
extension = '.hdf5',
bob_hldi_instance = None,
hldi_type = "pad",
groups = ['train', 'dev', 'eval'],
protocol = 'grandtest',
purposes=['real', 'attack'],
allow_missing_files = True,
``data_folder`` : str
A directory containing the training data.
``transform`` : callable
A function/transform that takes in a PIL image, and returns a
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
``extension`` : str
Extension of the data files. Default: ".hdf5".
Note: this is the only extension supported at the moment.
``bob_hldi_instance`` : object
An instance of the HLDI interface. Only HLDI's of bob.pad.face
are currently supported.
``hldi_type`` : str
String defining the type of the HLDI. Default: "pad".
Note: this is the only option currently supported.
``groups`` : str or [str]
The groups for which the clients should be returned.
Usually, groups are one or more elements of ['train', 'dev', 'eval'].
Default: ['train', 'dev', 'eval'].
``protocol`` : str
The protocol for which the clients should be retrieved.
Default: 'grandtest'.
``purposes`` : str or [str]
The purposes for which File objects should be retrieved.
Usually it is either 'real' or 'attack'.
Default: ['real', 'attack'].
``allow_missing_files`` : str or [str]
The missing files in the ``data_folder`` will not break the
execution if set to True.
Default: True.
self.data_folder = data_folder
self.transform = transform
self.extension = extension
self.bob_hldi_instance = bob_hldi_instance
self.hldi_type = hldi_type
self.groups = groups
self.protocol = protocol
self.purposes = purposes
self.allow_missing_files = allow_missing_files
if bob_hldi_instance is not None:
files = bob_hldi_instance.objects(groups = self.groups,
protocol = self.protocol,
purposes = self.purposes,
file_names_labels_keys = get_file_names_and_labels(files = files,
data_folder = self.data_folder,
extension = self.extension,
hldi_type = self.hldi_type)
if self.allow_missing_files: # return only existing files
file_names_labels_keys = [f for f in file_names_labels_keys if os.path.isfile(f[0])]
# TODO - add behaviour similar to image folder
file_names_labels_keys = []
self.file_names_labels_keys = file_names_labels_keys
def __getitem__(self, index):
Returns an image, possibly transformed, and a target class given index.
``index`` : int.
An index of the sample to return.
``pil_img`` : Tensor or PIL Image
If ``self.transform`` is defined the output is the torch.Tensor,
otherwise the output is an instance of the PIL.Image.Image class.
``target`` : int
Index of the class.
path, target, key = self.file_names_labels_keys[index]
with h5py.File(path, "r") as f_h5py:
img_array = np.array(f_h5py.get(key+'/array')) # The size now is (3 x W x H)
if isinstance(self.transform, transforms.Compose): # if an instance of torchvision composed transformation
if len(img_array.shape) == 3: # for color images
img_array_tr = np.swapaxes(img_array, 1, 2)
img_array_tr = np.swapaxes(img_array_tr, 0, 2)
pil_img = PIL.Image.fromarray( img_array_tr ) # convert to PIL from array of size (H x W x 3)
else: # for gray-scale images
pil_img = PIL.Image.fromarray( img_array, 'L' ) # convert to PIL from array of size (H x W)
if self.transform is not None:
pil_img = self.transform(pil_img)
else: # if custom transformation function is given
img_array_transformed = self.transform(img_array)
return torch.Tensor(img_array_transformed).unsqueeze(0), target # convert array to Tensor, also return target
return pil_img, target
def __len__(self):
``len`` : int
The length of the file list.
return len(self.file_names_labels_keys)
......@@ -31,7 +31,7 @@ import argparse
import importlib
import os
from bob.pad.face.database.pytorch import DataFolder
from bob.learn.pytorch.datasets import DataFolder
import torch
from import DataLoader
......@@ -91,10 +91,10 @@ def parse_arguments(cmd_params=None):
parser.add_argument("-c", "--config-file", type=str, help="Relative name of the config file defining "
"the network, training data, and training parameters.",
default = "autoencoder/")
default = "autoencoder/")
parser.add_argument("-cg", "--config-group", type=str, help="Name of the group, where config file is stored.",
default = "bob.pad.face.config.pytorch")
default = "bob.learn.pytorch.config")
parser.add_argument("-p", "--pretrained-model-path", type=str, help="Absolute name of the file, containing pre-trained Network "
"model, to de used for Network initialization before training.",
......@@ -72,7 +72,7 @@ setup(
' = bob.learn.pytorch.scripts.train_cnn:main',
' = bob.learn.pytorch.scripts.train_dcgan:main',
' = bob.learn.pytorch.scripts.train_conditionalgan:main',
' = bob.pad.face.script.pytorch.pytorch_train:main',
' = bob.learn.pytorch.scripts.pytorch_train:main',
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment