Commit 9ae5047d authored by Anjith GEORGE

Clean up code

parent 884dec6d
Merge request !22: Cross validation
@@ -36,7 +36,7 @@ from bob.learn.pytorch.datasets import ChannelSelect, RandomHorizontalFlipImage
 data_folder_train='/idiap/temp/ageorge/WMCA/preprocessed/'
-output_base_path='/idiap/temp/ageorge/Pytorch_WMCA/MCCNNv1_new_2/'
+output_base_path='/idiap/temp/ageorge/Pytorch_WMCA/MCCNNv1/'
 unseen_protocols=['','-LOO_fakehead','-LOO_flexiblemask','-LOO_glasses','-LOO_papermask','-LOO_prints','-LOO_replay','-LOO_rigidmask']
from torchvision import transforms
from bob.learn.pytorch.architectures import MCCNN
from bob.learn.pytorch.datasets import DataFolder
from bob.pad.face.database import BatlPadDatabase
from bob.learn.pytorch.datasets import ChannelSelect, RandomHorizontalFlipImage
#==============================================================================
# Load the dataset
""" The steps are as follows
1. Initialize a databae instance, with the protocol, groups and number of frames
(currently for the ones in 'bob.pad.face', and point 'data_folder_train' to the preprocessed directory )
Note: Here we assume that we have already preprocessed the with `spoof.py` script and dumped it to location
pointed to by 'data_folder_train'.
2. Specify the transform to be used on the images. It can be instances of `torchvision.transforms.Compose` or custom functions.
3. Initialize the `data_folder` class with the database instance and all other parameters. This dataset instance is used in
the trainer class
4. Initialize the network architecture with required arguments.
5. Define the parameters for the trainer.
"""
#==============================================================================
# Initialize the bob database instance
data_folder_train='/idiap/temp/ageorge/WMCA/preprocessed/'
output_base_path='/idiap/temp/ageorge/Pytorch_WMCA/MCCNNv1_CV_1/'
unseen_protocols=['','-LOO_fakehead','-LOO_flexiblemask','-LOO_glasses','-LOO_papermask','-LOO_prints','-LOO_replay','-LOO_rigidmask']
PROTOCOL_INDEX=0
####################################################################
frames=50
extension='.h5'
train_groups=['train'] # only 'train' group is used for training the network
val_groups=['dev']
do_crossvalidation=True
####################################################################
if do_crossvalidation:
    phases=['train','val']
else:
    phases=['train']

groups={"train":['train'],"val":['dev']}
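# With cross-validation enabled, every epoch runs a 'train' pass over the
# 'train' group followed by a 'val' pass over the 'dev' group; with it
# disabled, only the 'train' phase runs and no validation-based model
# selection takes place.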
protocols="grandtest-color-50"+unseen_protocols[PROTOCOL_INDEX] # makeup is excluded anyway here
exclude_attacks_list=["makeup"]

bob_hldi_instance = BatlPadDatabase(
    protocol=protocols,
    original_directory=data_folder_train,
    original_extension=extension,
    landmark_detect_method="mtcnn",  # detect annotations using mtcnn
    exclude_attacks_list=exclude_attacks_list,
    exclude_pai_all_sets=True,  # exclude makeup from all the sets, which is the default behavior for grandtest protocol
    append_color_face_roi_annot=False)
#==============================================================================
# Initialize the torch dataset, subselect channels from the pretrained files if needed.
SELECTED_CHANNELS = [0,1,2,3]
####################################################################
img_transform={}

img_transform['train'] = transforms.Compose([ChannelSelect(selected_channels=SELECTED_CHANNELS),
                                             RandomHorizontalFlipImage(p=0.5),
                                             transforms.ToTensor()])

img_transform['val'] = transforms.Compose([ChannelSelect(selected_channels=SELECTED_CHANNELS),
                                           transforms.ToTensor()])
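# ChannelSelect keeps only the channels listed in SELECTED_CHANNELS from the
# stacked multi-channel input; RandomHorizontalFlipImage flips a sample with
# probability p=0.5 (presumably all channels together, so the modalities stay
# spatially aligned). Flipping is applied to the training set only, so that
# validation epochs remain comparable across runs.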
dataset={}

for phase in phases:
    dataset[phase] = DataFolder(data_folder=data_folder_train,
                                transform=img_transform[phase],
                                extension='.hdf5',
                                bob_hldi_instance=bob_hldi_instance,
                                groups=groups[phase],
                                protocol=protocols,
                                purposes=['real', 'attack'],
                                allow_missing_files=True)
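# Optional sanity check (a sketch, not part of the original config; it assumes
# DataFolder behaves like a standard torch Dataset yielding (image, label)
# pairs -- verify against your bob.learn.pytorch version). Left commented out
# so that loading this config stays side-effect free:
#
# from torch.utils.data import DataLoader
# check_loader = DataLoader(dataset['train'], batch_size=4, shuffle=True)
# images, labels = next(iter(check_loader))
# print(images.shape)  # expected: (4, NUM_CHANNELS, H, W)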
#==============================================================================
# Specify other training parameters
NUM_CHANNELS = len(SELECTED_CHANNELS)
ADAPTED_LAYERS = 'conv1-block1-ffc'
####################################################################
ADAPT_REF_CHANNEL = False
####################################################################
batch_size = 32
num_workers = 0
epochs=25
learning_rate=0.0001
seed = 3
use_gpu = False
adapted_layers = ADAPTED_LAYERS
adapt_reference_channel = ADAPT_REF_CHANNEL
verbose = 2
UID = "_".join([str(i) for i in SELECTED_CHANNELS])+"_"+str(ADAPT_REF_CHANNEL)+"_"+ADAPTED_LAYERS+"_"+str(NUM_CHANNELS)+"_"+protocols
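# With the settings above, UID resolves to:
# "0_1_2_3_False_conv1-block1-ffc_4_grandtest-color-50"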
training_logs= output_base_path+UID+'/train_log_dir/'
output_dir = output_base_path+UID
#==============================================================================
# Load the architecture
assert(len(SELECTED_CHANNELS)==NUM_CHANNELS)
network=MCCNN(num_channels = NUM_CHANNELS)
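# Optional smoke test (a sketch; the 128x128 input size is an assumption based
# on the LightCNN-style backbone MCCNN builds on -- check the size of your
# preprocessed images before relying on it). Commented out to keep loading the
# config side-effect free:
#
# import torch
# with torch.no_grad():
#     scores = network(torch.randn(1, NUM_CHANNELS, 128, 128))
# print(scores.shape)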
#==============================================================================
"""
Note: Running on GPU
jman submit --queue gpu \
--name mccnnv2 \
--log-dir /idiap/temp/ageorge/Pytorch_WMCA/MCCNNv2/logs/ \
--environment="PYTHONUNBUFFERED=1" -- \
./bin/train_mccnn.py \
/idiap/user/ageorge/WORK/COMMON_ENV_PAD_BATL_DB/src/bob.learn.pytorch/bob/learn/pytorch/config/mccnn/wmca_mccnn.py --use-gpu -vvv
Note: Running on CPU
./bin/train_mccnn.py \
/idiap/user/ageorge/WORK/COMMON_ENV_PAD_BATL_DB/src/bob.learn.pytorch/bob/learn/pytorch/config/mccnn/wmca_mccnn.py -vvv
"""
@@ -54,7 +54,8 @@ class FASNetTrainer(object):
is not adapted, so that it can be used for Face recognition as well, default: `False`.
verbosity_level: int
The level of verbosity output to stdout
do_crossvalidation: bool
If set to `True`, performs validation in each epoch and stores the best model based on validation loss.
"""
self.network = network
self.batch_size = batch_size
@@ -189,15 +190,12 @@ class FASNetTrainer(object):
# let's go
for epoch in range(start_epoch, n_epochs):
    # in the epoch
    train_loss_history = []
    val_loss_history = []

    for phase in self.phases:
        if phase == 'train':
@@ -296,14 +294,6 @@ class FASNetTrainer(object):
except:
    pass
# Log images
# logimg=img.view(-1,img.size()[1]*224, 128)[:10].cpu().numpy()
# info = { 'images': logimg}
# for tag, images in info.items():
# self.tf_logger.image_summary(tag, images, epoch+1)
######################################## </Logging> ###################################
@@ -311,10 +301,8 @@ class FASNetTrainer(object):
logger.info("EPOCH {} DONE".format(epoch+1))
-            # comment it out after debugging
-            if epoch>=24:
-                self.save_model(output_dir, epoch=(epoch+1), iteration=0, losses=losses)
+            self.save_model(output_dir, epoch=(epoch+1), iteration=0, losses=losses)
## load the best weights
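# The selection pattern this merge request adds (a sketch reconstructed from
# the MCCNNTrainer hunks further down; exact names may differ in the full file):
#
# if self.do_crossvalidation and epoch_val_loss < best_val_loss:
#     best_val_loss = epoch_val_loss
#     best_model_wts = copy.deepcopy(self.network.state_dict())
# ...
# self.network.load_state_dict(best_model_wts)  # reload the weights that did best on 'dev'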
@@ -54,6 +54,8 @@ class MCCNNTrainer(object):
is not adapted, so that it can be used for Face recognition as well, default: `False`.
verbosity_level: int
The level of verbosity output to stdout
do_crossvalidation: bool
If set to `True`, performs validation in each epoch and stores the best model based on validation loss.
"""
self.network = network
@@ -220,7 +222,6 @@ class MCCNNTrainer(object):
# let's go
for epoch in range(start_epoch, n_epochs):
    # in the epoch
@@ -228,9 +229,6 @@ class MCCNNTrainer(object):
val_loss_history = []
for phase in self.phases:
    if phase == 'train':
@@ -238,7 +236,6 @@ class MCCNNTrainer(object):
else:
    self.network.eval()  # Set model to evaluate mode
for i, data in enumerate(dataloader[phase], 0):
@@ -265,7 +262,6 @@ class MCCNNTrainer(object):
# weights for samples, should help with data imbalance
self.criterion.weight = weights
optimizer.zero_grad()
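# grad tracking is enabled only for the 'train' phase; in 'val' the forward
# pass below runs without building the autograd graph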
with torch.set_grad_enabled(phase == 'train'):
@@ -274,10 +270,9 @@ class MCCNNTrainer(object):
loss = self.criterion(output, labelsv)

if phase == 'train':
    loss.backward()
    optimizer.step()
    train_loss_history.append(loss.item())
else:
    val_loss_history.append(loss.item())
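# the per-batch losses collected above are aggregated into epoch_train_loss
# and epoch_val_loss, which feed the logger and the best-model selection below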
@@ -305,11 +300,8 @@ class MCCNNTrainer(object):
best_model_wts = copy.deepcopy(self.network.state_dict())
######################################## <Logging> ###################################
if self.do_crossvalidation:
    info = {'train_loss':epoch_train_loss,'val_loss':epoch_val_loss}
else:
    info = {'train_loss':epoch_train_loss}
@@ -343,16 +335,12 @@ class MCCNNTrainer(object):
# do stuff - like saving models
logger.info("EPOCH {} DONE".format(epoch+1))
-            # comment it out after debugging
-            if epoch>=24:
-                self.save_model(output_dir, epoch=(epoch+1), iteration=0, losses=losses)
+            self.save_model(output_dir, epoch=(epoch+1), iteration=0, losses=losses)
## load the best weights
self.network.load_state_dict(best_model_wts)
# the best weights are saved separately, under the sentinel epoch number 100
self.save_model(output_dir, epoch=100, iteration=0, losses=losses)