Skip to content
Snippets Groups Projects
Commit 11003be4 authored by Anjith GEORGE's avatar Anjith GEORGE
Browse files

Modified config file

parent 5222d95d
No related branches found
No related tags found
1 merge request!18MCCNN trainer
...@@ -7,6 +7,8 @@ from bob.learn.pytorch.datasets import DataFolder ...@@ -7,6 +7,8 @@ from bob.learn.pytorch.datasets import DataFolder
from bob.pad.face.database import BatlPadDatabase from bob.pad.face.database import BatlPadDatabase
from bob.learn.pytorch.datasets import ChannelSelect
#============================================================================== #==============================================================================
# Load the dataset # Load the dataset
...@@ -39,7 +41,9 @@ protocols="grandtest-color*depth*infrared*thermal-{}".format(frames) # makeup is ...@@ -39,7 +41,9 @@ protocols="grandtest-color*depth*infrared*thermal-{}".format(frames) # makeup is
exlude_attacks_list=["makeup"] exlude_attacks_list=["makeup"]
img_transform_train = transforms.Compose([transforms.ToPILImage(),transforms.RandomHorizontalFlip(),transforms.ToTensor()])# Add p=0.5 later SELECTED_CHANNELS = [0] # selects only color, depth and infrared
img_transform_train = transforms.Compose([ChannelSelect(selected_channels = SELECTED_CHANNELS),transforms.ToPILImage(),transforms.RandomHorizontalFlip(),transforms.ToTensor()])# Add p=0.5 later
bob_hldi_instance_train = BatlPadDatabase( bob_hldi_instance_train = BatlPadDatabase(
...@@ -63,7 +67,9 @@ dataset = DataFolder(data_folder=data_folder_train, ...@@ -63,7 +67,9 @@ dataset = DataFolder(data_folder=data_folder_train,
#============================================================================== #==============================================================================
# Load the architecture # Load the architecture
network=MCCNN(num_channels=4) NUM_CHANNELS = 1
network=MCCNN(num_channels = NUM_CHANNELS)
#============================================================================== #==============================================================================
# Specify other training parameters # Specify other training parameters
......
...@@ -15,13 +15,12 @@ import time ...@@ -15,13 +15,12 @@ import time
import os import os
""" """
#TODO: #TODO:
#0. Add class balancing #0. Add class balancing - done with the weights ; import the function from utils
#1. Logging to tensorboardX or a simpler logger #1. Logging to tensorboardX or a simpler logger:
#2. Support for Validation set and validation loss #2. Support for Validation set and validation loss
#3. Use to(device) instead of .cuda()? #3. Use to(device) instead of .cuda()?: Eventually migrate everything to this!
#4. Functionality to select channels from the dataloader: may be move this to the datafolder class #4. Moving more arguments to config?
#5. Moving more arguments to config? #5. Implement the selection of channels as a transform: Added now ChannelSelect
#6. Implement the selection of channels as a transform
""" """
def comp_bce_loss_weights(target): def comp_bce_loss_weights(target):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment