Skip to content
Snippets Groups Projects
Commit 7c548570 authored by Anjith GEORGE's avatar Anjith GEORGE Committed by Anjith GEORGE
Browse files

Added docs

parent da36d3a7
No related branches found
No related tags found
1 merge request!18MCCNN trainer
......@@ -28,6 +28,84 @@ All the parameters required to train MCCNN are defined in the configuration file
The ``config.py`` file should contain atleast the network definition and the dataset class to be used for training.
It can also define the transforms, number of channels in mccnn, training parameters such as number of epochs, learning rate and so on.
a. Structure of the config file
An example configuration file to train MCCNN with WMCA dataset is shown below
.. code-block:: python
from torchvision import transforms
from bob.learn.pytorch.architectures import MCCNN
from bob.learn.pytorch.datasets import DataFolder
from bob.pad.face.database import BatlPadDatabase
from bob.learn.pytorch.datasets import ChannelSelect
### initialize bob database instance ###
data_folder_train='<PREPROCESSED_FOLDER>'
frames=50
extension='.h5'
train_groups=['train'] # only 'train' group is used for training the network
protocols="grandtest-color*depth*infrared*thermal-{}".format(frames) # makeup is excluded anyway here
exlude_attacks_list=["makeup"]
bob_hldi_instance_train = BatlPadDatabase(
protocol=protocols,
original_directory=data_folder_train,
original_extension=extension,
landmark_detect_method="mtcnn", # detect annotations using mtcnn
exclude_attacks_list=exlude_attacks_list,
exclude_pai_all_sets=True, # exclude makeup from all the sets, which is the default behavior for grandtest protocol
append_color_face_roi_annot=False)
#==============================================================================
# Initialize the torch dataset, subselect channels from the pretrained files if needed.
SELECTED_CHANNELS = [0,1,2] # selects only color, depth and infrared
img_transform_train = transforms.Compose([ChannelSelect(selected_channels = SELECTED_CHANNELS),transforms.ToPILImage(),transforms.RandomHorizontalFlip(),transforms.ToTensor()])# Add p=0.5 later
dataset = DataFolder(data_folder=data_folder_train,
transform=img_transform_train,
extension='.hdf5',
bob_hldi_instance=bob_hldi_instance_train,
groups=train_groups,
protocol=protocols,
purposes=['real', 'attack'],
allow_missing_files=True)
#==============================================================================
# Load the architecture
NUM_CHANNELS = 3
network=MCCNN(num_channels = NUM_CHANNELS)
#==============================================================================
# Specify other training parameters
batch_size = 64
epochs=25
learning_rate=0.0001
seed = 3
output_dir = 'training_mccn'
use_gpu = False
adapted_layers = 'conv1-ffc'
adapt_reference_channel = False
verbose = 2
Once the config file is defined, training the network can be done with the following code:
.. code-block:: sh
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment