Skip to content
Snippets Groups Projects

Fine-tuning of autoencoders on multi-channel data

Merged Olegs NIKISINS requested to merge ae_mc_finetune into master
4 files
+ 252
1
Compare changes
  • Side-by-side
  • Inline
Files
4
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
#==============================================================================
# Import here:
from torchvision import transforms
from bob.pad.face.database import BatlPadDatabase
from torch import nn
#==============================================================================
# Define parameters here:
"""
Note: do not change names of the below constants.
"""
NUM_EPOCHS = 50 # Maximum number of epochs
BATCH_SIZE = 32 # Size of the batch
LEARNING_RATE = 1e-3 # Learning rate
NUM_WORKERS = 8 # The number of workers for the DataLoader
"""
Transformations to be applied sequentially to the input PIL image.
Note: the variable name ``transform`` must be the same in all configuration files.
"""
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
"""
Set the parameters of the DataFolder dataset class.
Note: do not change the name ``kwargs``.
"""
ORIGINAL_DIRECTORY = ""
ORIGINAL_EXTENSION = ".h5" # extension of the data files
PROTOCOL = 'grandtest-color*infrared*depth-10' # use 10 frames for PAD experiments
annotations_temp_dir = ""
bob_hldi_instance = BatlPadDatabase(protocol=PROTOCOL,
original_directory=ORIGINAL_DIRECTORY,
original_extension=ORIGINAL_EXTENSION,
annotations_temp_dir = annotations_temp_dir,
landmark_detect_method="mtcnn", # detect annotations using mtcnn
exclude_attacks_list=['makeup'],
exclude_pai_all_sets=True, # exclude makeup from all the sets, which is the default behavior for grandtest protocol
append_color_face_roi_annot=False)
kwargs = {}
kwargs["data_folder"] = "NO NEED TO SET HERE, WILL BE SET IN THE TRAINING SCRIPT"
kwargs["transform"] = transform
kwargs["extension"] = '.hdf5'
kwargs["bob_hldi_instance"] = bob_hldi_instance
kwargs["hldi_type"] = "pad"
kwargs["groups"] = ['train']
kwargs["protocol"] = 'grandtest'
kwargs["purposes"] = ['real']
kwargs["allow_missing_files"] = True
"""
Define the network to be trained as a class, named ``Network``.
Note: Do not change the name of the below class.
"""
from bob.learn.pytorch.architectures import ConvAutoencoder as Network
"""
Define the loss to be used for training.
Note: do not change the name of the below variable.
"""
loss_type = nn.MSELoss()
"""
OPTIONAL: if not defined loss will be computed in the training script.
See training script for details
Define the function to compute the loss. Don't change the signature of this
function.
"""
# we don't define the loss_function for this configuration
#def loss_function(output, img, target):
Loading