Commit abf14850 authored by Anjith GEORGE's avatar Anjith GEORGE

Adds more unit tests and updated doc for mccnn

parent 93599179
Pipeline #27518 failed with stage
in 62 minutes and 5 seconds
......@@ -220,6 +220,7 @@ class DummyDataSetMCCNN(Dataset):
sample = data, label
return sample
def test_MCCNNtrainer():
from ..architectures import MCCNN
......@@ -229,13 +230,44 @@ def test_MCCNNtrainer():
dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetMCCNN(), batch_size=32, shuffle=True)
from ..trainers import MCCNNTrainer
trainer = MCCNNTrainer(net, verbosity_level=3)
trainer = MCCNNTrainer(net, verbosity_level=3, do_crossvalidation=False)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
os.remove('model_1_0.pth')
class DummyDataSetFASNet(Dataset):
def __init__(self):
pass
def __len__(self):
return 100
def __getitem__(self, idx):
data = numpy.random.rand(3, 224,224).astype("float32")
label = numpy.random.randint(2)
sample = data, label
return sample
def test_FASNettrainer():
from ..architectures import FASNet
net = FASNet()
dataloader={}
dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetFASNet(), batch_size=32, shuffle=True)
from ..trainers import FASNetTrainer
trainer = FASNetTrainer(net, verbosity_level=3,do_crossvalidation=False)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
os.remove('model_1_0.pth')
class DummyDataSetGAN(Dataset):
def __init__(self):
pass
......@@ -357,6 +389,14 @@ def test_extractors():
data = numpy.random.rand(4, 128, 128).astype("float32")
output = extractor(data)
assert output.shape[0] == 1
# FASNet
from ..extractor.image import FASNetExtractor
extractor = FASNetExtractor(num_channels_used=4)
# this architecture expects RGB images of size 3x224x224 channel images
data = numpy.random.rand(3, 224, 224).astype("float32")
output = extractor(data)
assert output.shape[0] == 1
def test_two_layer_mlp():
"""
......
......@@ -37,7 +37,6 @@ An example configuration file to train MCCNN with WMCA dataset is shown below
.. code-block:: python
from torchvision import transforms
from bob.learn.pytorch.architectures import MCCNN
......@@ -46,66 +45,119 @@ An example configuration file to train MCCNN with WMCA dataset is shown below
from bob.pad.face.database import BatlPadDatabase
from bob.learn.pytorch.datasets import ChannelSelect
from bob.learn.pytorch.datasets import ChannelSelect, RandomHorizontalFlipImage
#==============================================================================
# Load the dataset
""" The steps are as follows
### initialize bob database instance ###
1. Initialize a databae instance, with the protocol, groups and number of frames
(currently for the ones in 'bob.pad.face', and point 'data_folder_train' to the preprocessed directory )
Note: Here we assume that we have already preprocessed the with `spoof.py` script and dumped it to location
pointed to by 'data_folder_train'.
data_folder_train='<PREPROCESSED_FOLDER>'
2. Specify the transform to be used on the images. It can be instances of `torchvision.transforms.Compose` or custom functions.
frames=50
3. Initialize the `data_folder` class with the database instance and all other parameters. This dataset instance is used in
the trainer class
4. Initialize the network architecture with required arguments.
5. Define the parameters for the trainer.
"""
#==============================================================================
# Initialize the bob database instance
data_folder_train= <PREPROCESSED_FOLDER>
output_base_path= <OUTPUT_PATH>
extension='.h5'
train_groups=['train'] # only 'train' group is used for training the network
protocols="grandtest-color*depth*infrared*thermal-{}".format(frames) # makeup is excluded anyway here
val_groups=['dev']
do_crossvalidation=True
#=======================
if do_crossvalidation:
phases=['train','val']
else:
phases=['train']
groups={"train":['train'],"val":['dev']}
protocols="grandtest-color-50"
exlude_attacks_list=["makeup"]
bob_hldi_instance_train = BatlPadDatabase(
bob_hldi_instance = BatlPadDatabase(
protocol=protocols,
original_directory=data_folder_train,
original_extension=extension,
landmark_detect_method="mtcnn", # detect annotations using mtcnn
exclude_attacks_list=exlude_attacks_list,
exclude_pai_all_sets=True, # exclude makeup from all the sets, which is the default behavior for grandtest protocol
exclude_pai_all_sets=True,
append_color_face_roi_annot=False)
#==============================================================================
# Initialize the torch dataset, subselect channels from the pretrained files if needed.
SELECTED_CHANNELS = [0,1,2] # selects only color, depth and infrared
SELECTED_CHANNELS = [0,1,2,3]
img_transform={}
img_transform_train = transforms.Compose([ChannelSelect(selected_channels = SELECTED_CHANNELS),transforms.ToPILImage(),transforms.RandomHorizontalFlip(),transforms.ToTensor()])# Add p=0.5 later
img_transform['train'] = transforms.Compose([ChannelSelect(selected_channels = SELECTED_CHANNELS),RandomHorizontalFlipImage(p=0.5),transforms.ToTensor()])
dataset = DataFolder(data_folder=data_folder_train,
transform=img_transform_train,
extension='.hdf5',
bob_hldi_instance=bob_hldi_instance_train,
groups=train_groups,
protocol=protocols,
purposes=['real', 'attack'],
allow_missing_files=True)
img_transform['val'] = transforms.Compose([ChannelSelect(selected_channels = SELECTED_CHANNELS),transforms.ToTensor()])
#==============================================================================
# Load the architecture
dataset={}
NUM_CHANNELS = 3
for phase in phases:
dataset[phase] = DataFolder(data_folder=data_folder_train,
transform=img_transform[phase],
extension='.hdf5',
bob_hldi_instance=bob_hldi_instance,
groups=groups[phase],
protocol=protocols,
purposes=['real', 'attack'],
allow_missing_files=True)
network=MCCNN(num_channels = NUM_CHANNELS)
#==============================================================================
# Specify other training parameters
batch_size = 64
NUM_CHANNELS = len(SELECTED_CHANNELS)
ADAPTED_LAYERS = 'conv1-block1-group1-ffc'
ADAPT_REF_CHANNEL = False
batch_size = 32
num_workers = 0
epochs=25
learning_rate=0.0001
seed = 3
output_dir = 'training_mccn'
use_gpu = False
adapted_layers = 'conv1-ffc'
adapt_reference_channel = False
adapted_layers = ADAPTED_LAYERS
adapt_reference_channel = ADAPT_REF_CHANNEL
verbose = 2
UID = "_".join([str(i) for i in SELECTED_CHANNELS])+"_"+str(ADAPT_REF_CHANNEL)+"_"+ADAPTED_LAYERS+"_"+str(NUM_CHANNELS)+"_"+protocols
training_logs= output_base_path+UID+'/train_log_dir/'
output_dir = output_base_path+UID
#==============================================================================
# Load the architecture
assert(len(SELECTED_CHANNELS)==NUM_CHANNELS)
network=MCCNN(num_channels = NUM_CHANNELS)
#==============================================================================
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment