Skip to content
Snippets Groups Projects
Commit d43f9f16 authored by Anjith GEORGE's avatar Anjith GEORGE
Browse files

Generic WIP

parent ffd5ad52
Branches
Tags
1 merge request!32Generic trainer cleaned
from .casia_webface import CasiaDataset from .casia_webface import CasiaDataset
from .casia_webface import CasiaWebFaceDataset from .casia_webface import CasiaWebFaceDataset
from .data_folder import DataFolder from .data_folder import DataFolder
from .data_folder_generic import DataFolderGeneric
# transforms # transforms
from .utils import FaceCropper from .utils import FaceCropper
......
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
#==============================================================================
# Import what is needed here:
import torch.utils.data as data
import os
import random
random.seed( a = 7 )
import numpy as np
from torchvision import transforms
import h5py
#==============================================================================
def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type = "pad"):
"""
Get absolute names of the corresponding file objects and their class labels,
as well as keys defining name of the frame to load the data from.
Attributes
----------
files : [File]
A list of files objects defined in the High Level Database Interface
of the particular database.
data_folder : str
A directory containing the training data.
extension : str
Extension of the data files. Default: ".hdf5" .
hldi_type : str
Type of the high level database interface. Default: "pad".
Note: this is the only type supported at the moment.
Returns
-------
file_names_labels_keys : [(str, int, str)]
A list of tuples, where each tuple contain an absolute filename,
a corresponding label of the class, and a key defining the name of the
frame to extract the data from.
"""
file_names_labels_keys = []
if hldi_type == "pad":
for f in files:
if f.attack_type is None:
label = 1
else:
label = 0
file_name = os.path.join(data_folder, f.path + extension)
if os.path.isfile(file_name): # if file is available:
with h5py.File(file_name, "r") as f_h5py:
file_keys = list(f_h5py.keys())
#removes the 'FrameIndexes' key
file_keys=[f for f in file_keys if f!='FrameIndexes' ]
# elements of tuples in the below list are as follows:
# a filename a key is extracted from,
# a label corresponding to the file,
# a key defining a frame from the file.
file_names_labels_keys = file_names_labels_keys + [(file_name, label, key) for file_name, label, key in
zip([file_name]*len(file_keys), [label]*len(file_keys), file_keys)]
return file_names_labels_keys
#==============================================================================
class DataFolderGeneric(data.Dataset):
"""
A generic data loader compatible with Bob High Level Database Interfaces
(HLDI). Only HLDI's of ``bob.pad.face`` are currently supported.
The basic functionality is composed of two steps: load the data from hdf5
file, and transform it using user defined transformation function.
Two types of user defined transformations are supported:
1. An instance of ``Compose`` transformation class from ``torchvision``
package.
2. A custom transformation function, which takes numpy.ndarray as input,
and returns a transformed Tensor. The dimensionality of the output tensor
must match the format expected by the network to be trained.
Note: if no special transformation is needed, the ``transform``
must at least convert an input numpy array to Tensor.
Attributes
----------
data_folder : str
A directory containing the training data. Note, that the training data
must be stored as a FrameContainers written to the hdf5 files. Other
formats are currently not supported.
transform : object
A function ``transform`` takes an input numpy.ndarray sample/image,
and returns a transformed version as a Tensor. Default: None.
extension : str
Extension of the data files. Default: ".hdf5".
Note: this is the only extension supported at the moment.
bob_hldi_instance : object
An instance of the HLDI interface. Only HLDI's of bob.pad.face
are currently supported.
hldi_type : str
String defining the type of the HLDI. Default: "pad".
Note: this is the only option currently supported.
groups : str or [str]
The groups for which the clients should be returned.
Usually, groups are one or more elements of ['train', 'dev', 'eval'].
Default: ['train', 'dev', 'eval'].
protocol : str
The protocol for which the clients should be retrieved.
Default: 'grandtest'.
purposes : str or [str]
The purposes for which File objects should be retrieved.
Usually it is either 'real' or 'attack'.
Default: ['real', 'attack'].
allow_missing_files : bool
The missing files in the ``data_folder`` will not break the
execution if set to True.
Default: True.
"""
def __init__(self, data_folder,
transform = None,
extension = '.hdf5',
bob_hldi_instance = None,
hldi_type = "pad",
groups = ['train', 'dev', 'eval'],
protocol = 'grandtest',
purposes=['real', 'attack'],
allow_missing_files = True,custom_func=None,
**kwargs):
"""
Attributes
----------
data_folder : str
A directory containing the training data.
transform : object
A function ``transform`` takes an input numpy.ndarray sample/image,
and returns a transformed version as a Tensor. Default: None.
extension : str
Extension of the data files. Default: ".hdf5".
Note: this is the only extension supported at the moment.
bob_hldi_instance : object
An instance of the HLDI interface. Only HLDI's of bob.pad.face
are currently supported.
hldi_type : str
String defining the type of the HLDI. Default: "pad".
Note: this is the only option currently supported.
groups : str or [str]
The groups for which the clients should be returned.
Usually, groups are one or more elements of ['train', 'dev', 'eval'].
Default: ['train', 'dev', 'eval'].
protocol : str
The protocol for which the clients should be retrieved.
Default: 'grandtest'.
purposes : str or [str]
The purposes for which File objects should be retrieved.
Usually it is either 'real' or 'attack'.
Default: ['real', 'attack'].
allow_missing_files : bool
The missing files in the ``data_folder`` will not break the
execution if set to True.
Default: True.
"""
self.data_folder = data_folder
self.transform = transform
self.extension = extension
self.bob_hldi_instance = bob_hldi_instance
self.hldi_type = hldi_type
self.groups = groups
self.protocol = protocol
self.purposes = purposes
self.allow_missing_files = allow_missing_files
self.custom_func = custom_func
if bob_hldi_instance is not None:
files = bob_hldi_instance.objects(groups = self.groups,
protocol = self.protocol,
purposes = self.purposes,
**kwargs)
file_names_labels_keys = get_file_names_and_labels(files = files,
data_folder = self.data_folder,
extension = self.extension,
hldi_type = self.hldi_type)
if self.allow_missing_files: # return only existing files
file_names_labels_keys = [f for f in file_names_labels_keys if os.path.isfile(f[0])]
else:
# TODO - add behaviour similar to image folder
file_names_labels_keys = []
self.file_names_labels_keys = file_names_labels_keys
#==========================================================================
def __getitem__(self, index):
"""
Returns a **transformed** sample/image and a target class, given index.
Two types of transformations are handled, see the doc-string of the
class.
Attributes
----------
index : int
An index of the sample to return.
Returns
-------
np_img : Tensor
Transformed sample.
target : int
Index of the class.
"""
path, target, key = self.file_names_labels_keys[index]
with h5py.File(path, "r") as f_h5py:
img_array = np.array(f_h5py.get(key+'/array')) # The size now is (3 x W x H)
if isinstance(self.transform, transforms.Compose): # if an instance of torchvision composed transformation
if len(img_array.shape) == 3: # for color or multi-channel images
img_array_tr = np.swapaxes(img_array, 1, 2)
img_array_tr = np.swapaxes(img_array_tr, 0, 2)
np_img =img_array_tr.copy() # np_img is numpy.ndarray of shape HxWxC
else: # for gray-scale images
np_img=np.expand_dims(img_array_tr,2) # np_img is numpy.ndarray of size HxWx1
if self.transform is not None:
np_img = self.transform(np_img) # after this transformation np_img should be a tensor
else: # if custom transformation function is given
img_array_transformed = self.transform(img_array)
return img_array_transformed, target
# NOTE: make sure ``img_array_transformed`` converted to Tensor in your custom ``transform`` function.
if self.custom_func is not None: # custom function to change the return to something else
return self.custom_func(np_img,target)
return np_img, target
#==========================================================================
def __len__(self):
"""
Returns
-------
len : int
The length of the file list.
"""
return len(self.file_names_labels_keys)
#!/usr/bin/env python
# encoding: utf-8
""" Train a Generic Net
Usage:
%(prog)s <configuration>
[--model=<string>] [--batch-size=<int>] [--num-workers=<int>][--epochs=<int>] [--save-interval=<int>]
[--learning-rate=<float>][--do-crossvalidation][--seed=<int>]
[--output-dir=<path>] [--use-gpu] [--verbose ...]
Arguments:
<configuration> A configuration file, defining the dataset and the network
Options:
-h, --help Shows this help message and exits
--model=<string> Filename of the model to load (if any).
--batch-size=<int> Batch size [default: 64]
--num-workers=<int> Number subprocesses to use for data loading [default: 0]
--epochs=<int> Number of training epochs [default: 20]
--save-interval=<int> Interval between saving epochs [default: 5]
--learning-rate=<float> Learning rate [default: 0.01]
--do-crossvalidation Whether to perform cross validation [default: False]
-S, --seed=<int> The random seed [default: 3]
-o, --output-dir=<path> Dir to save stuff [default: training]
-g, --use-gpu Use the GPU
-v, --verbose Increase the verbosity (may appear multiple times).
Note that arguments provided directly by command-line will override the ones in the configuration file.
Example:
To run the training process
$ %(prog)s config.py
See '%(prog)s --help' for more information.
"""
import os, sys
import pkg_resources
import torch
import numpy
from docopt import docopt
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
from bob.extension.config import load
from bob.learn.pytorch.trainers import GenericTrainer
from bob.learn.pytorch.utils import get_parameter
version = pkg_resources.require('bob.learn.pytorch')[0].version
def main(user_input=None):
# Parse the command-line arguments
if user_input is not None:
arguments = user_input
else:
arguments = sys.argv[1:]
prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train a Generic Network (%s)' % version,)
# load configuration file
configuration = load([os.path.join(args['<configuration>'])])
# get the pre-trained model file, if any
model = args['--model']
if hasattr(configuration, 'model'):
model = configuration.model
# get various parameters, either from config file or command-line
batch_size = get_parameter(args, configuration, 'batch_size', 64)
num_workers = get_parameter(args, configuration, 'num_workers', 0)
epochs = get_parameter(args, configuration, 'epochs', 20)
save_interval = get_parameter(args, configuration, 'save_interval', 5)
learning_rate = get_parameter(args, configuration, 'learning_rate', 0.01)
seed = get_parameter(args, configuration, 'seed', 3)
output_dir = get_parameter(args, configuration, 'output_dir', 'training')
use_gpu = get_parameter(args, configuration, 'use_gpu', False)
verbosity_level = get_parameter(args, configuration, 'verbose', 0)
do_crossvalidation = get_parameter(args, configuration, 'do_crossvalidation', False)
bob.core.log.set_verbosity_level(logger, verbosity_level)
bob.io.base.create_directories_safe(output_dir)
# print parameters
logger.debug("Model file = {}".format(model))
logger.debug("Batch size = {}".format(batch_size))
logger.debug("Num workers = {}".format(num_workers))
logger.debug("Epochs = {}".format(epochs))
logger.debug("Save interval = {}".format(save_interval))
logger.debug("Learning rate = {}".format(learning_rate))
logger.debug("Seed = {}".format(seed))
logger.debug("Output directory = {}".format(output_dir))
logger.debug("Use GPU = {}".format(use_gpu))
logger.debug("Perform cross validation = {}".format(do_crossvalidation))
# use new interface
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# process on the arguments / options
torch.manual_seed(seed)
if use_gpu:
torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available() and not use_gpu:
device="cpu"
logger.warn("You have a CUDA device, so you should probably run with --use-gpu")
logger.debug("Device used for training = {}".format(device))
# Which device to use is figured out at this point, no need to use `use-gpu` flag anymore
# get data
if hasattr(configuration, 'dataset'):
dataloader={}
if not do_crossvalidation:
logger.info("There are {} training samples".format(len(configuration.dataset['train'])))
dataloader['train'] = torch.utils.data.DataLoader(configuration.dataset['train'], batch_size=batch_size, num_workers=num_workers, shuffle=True)
else:
dataloader['train'] = torch.utils.data.DataLoader(configuration.dataset['train'], batch_size=batch_size, num_workers=num_workers, shuffle=True)
dataloader['val'] = torch.utils.data.DataLoader(configuration.dataset['val'], batch_size=batch_size, num_workers=num_workers, shuffle=True)
logger.info("There are {} training samples".format(len(configuration.dataset['train'])))
logger.info("There are {} validation samples".format(len(configuration.dataset['val'])))
else:
logger.error("Please provide a dataset in your configuration file !")
sys.exit()
assert(hasattr(configuration, 'optimizer'))
# train the network
if hasattr(configuration, 'network'):
trainer = GenericTrainer(configuration.network, configuration.optimizer,configuration.compute_loss,learning_rate=learning_rate, batch_size=batch_size, device=device, verbosity_level=verbosity_level,tf_logdir=output_dir+'/tf_logs',do_crossvalidation=do_crossvalidation, save_interval=save_interval)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, model=model)
else:
logger.error("Please provide a network in your configuration file !")
sys.exit()
...@@ -233,6 +233,29 @@ def test_CNNtrainer(): ...@@ -233,6 +233,29 @@ def test_CNNtrainer():
os.remove('model_1_0.pth') os.remove('model_1_0.pth')
def test_Generictrainer():
from ..architectures import LightCNN9
net = LightCNN9(20)
dataloader={}
dataloader['train'] = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
dataloader['val']= torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),lr = 0.1)
criterion = torch.nn.BCELoss()
from ..trainers import GenericTrainer
trainer = GenericTrainer(net, verbosity_level=3)
trainer.train(dataloader, n_epochs=1, output_dir='.')
import os
assert os.path.isfile('model_1_0.pth')
os.remove('model_1_0.pth')
class DummyDataSetMCCNN(Dataset): class DummyDataSetMCCNN(Dataset):
def __init__(self): def __init__(self):
pass pass
......
#!/usr/bin/env python
# encoding: utf-8
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from .tflog import Logger
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
import time
import os
import copy
class GenericTrainer(object):
"""
Class to train a generic NN; all the parameters are provided in configs
Attributes
----------
network: :py:class:`torch.nn.Module`
The network to train
batch_size: int
The size of your minibatch
optimizer: :py:class:`torch.optim.Optimizer`
Optimizer object to be used. Initialized in the config file.
device: str
Device which will be used for training the model
verbosity_level: int
The level of verbosity output to stdout
"""
def __init__(self, network, optimizer, compute_loss, learning_rate=0.0001, batch_size=64, device='cpu', verbosity_level=2, tf_logdir='tf_logs',do_crossvalidation=False, save_interval=5):
""" Init function . The layers to be adapted in the network is selected and the gradients are set to `True`
for the layers which needs to be adapted.
Parameters
----------
network: :py:class:`torch.nn.Module`
The network to train
batch_size: int
The size of your minibatch
device: str
Device which will be used for training the model
verbosity_level: int
The level of verbosity output to stdout
do_crossvalidation: bool
If set to `True`, performs validation in each epoch and stores the best model based on validation loss.
"""
self.network = network
self.batch_size = batch_size
self.optimizer=optimizer
self.compute_loss=compute_loss
self.device = device
self.learning_rate=learning_rate
self.save_interval=save_interval
self.do_crossvalidation=do_crossvalidation
if self.do_crossvalidation:
phases=['train','val']
else:
phases=['train']
self.phases=phases
# Move the network to device
self.network.to(self.device)
bob.core.log.set_verbosity_level(logger, verbosity_level)
self.tf_logger = Logger(tf_logdir)
# Setting the gradients to true for the layers which needs to be adapted
def load_model(self, model_filename):
"""Loads an existing model
Parameters
----------
model_file: str
The filename of the model to load
Returns
-------
start_epoch: int
The epoch to start with
start_iteration: int
The iteration to start with
losses: list(float)
The list of losses from previous training
"""
cp = torch.load(model_filename)
self.network.load_state_dict(cp['state_dict'])
start_epoch = cp['epoch']
start_iter = cp['iteration']
losses = cp['loss']
return start_epoch, start_iter, losses
def save_model(self, output_dir, epoch=0, iteration=0, losses=None):
"""Save the trained network
Parameters
----------
output_dir: str
The directory to write the models to
epoch: int
the current epoch
iteration: int
the current (last) iteration
losses: list(float)
The list of losses since the beginning of training
"""
saved_filename = 'model_{}_{}.pth'.format(epoch, iteration)
saved_path = os.path.join(output_dir, saved_filename)
logger.info('Saving model to {}'.format(saved_path))
cp = {'epoch': epoch,
'iteration': iteration,
'loss': losses,
'state_dict': self.network.cpu().state_dict()
}
torch.save(cp, saved_path)
self.network.to(self.device)
def train(self, dataloader, n_epochs=25, output_dir='out', model=None):
"""Performs the training.
Parameters
----------
dataloader: :py:class:`torch.utils.data.DataLoader`
The dataloader for your data
n_epochs: int
The number of epochs you would like to train for
learning_rate: float
The learning rate for Adam optimizer.
output_dir: str
The directory where you would like to save models
model: str
The path to a pretrained model file to start training from; this is the PAD model; not the LightCNN model
"""
# if model exists, load it
if model is not None:
start_epoch, start_iter, losses = self.load_model(model)
logger.info('Starting training at epoch {}, iteration {} - last loss value is {}'.format(start_epoch, start_iter, losses[-1]))
else:
start_epoch = 0
start_iter = 0
losses = []
logger.info('Starting training from scratch')
for name, param in self.network.named_parameters():
if param.requires_grad == True:
logger.info('Layer to be adapted from grad check : {}'.format(name))
# setup optimizer
self.network.train(True)
best_model_wts = copy.deepcopy(self.network.state_dict())
best_loss = float("inf")
# let's go
for epoch in range(start_epoch, n_epochs):
# in the epoch
train_loss_history=[]
val_loss_history = []
for phase in self.phases:
if phase == 'train':
self.network.train() # Set model to training mode
else:
self.network.eval() # Set model to evaluate mode
for i, data in enumerate(dataloader[phase], 0):
if i >= start_iter:
start = time.time()
# get data from dataset
img, labels = data
self.optimizer.zero_grad()
with torch.set_grad_enabled(phase == 'train'):
loss = self.compute_loss(self.network, img, labels, self.device)
if phase == 'train':
loss.backward()
self.optimizer.step()
train_loss_history.append(loss.item())
else:
val_loss_history.append(loss.item())
end = time.time()
logger.info("[{}/{}][{}/{}] => Loss = {} (time spent: {}), Phase {}".format(epoch, n_epochs, i, len(dataloader[phase]), loss.item(), (end-start),phase))
losses.append(loss.item())
epoch_train_loss=np.mean(train_loss_history)
logger.info("Train Loss : {} epoch : {}".format(epoch_train_loss,epoch))
if self.do_crossvalidation:
epoch_val_loss=np.mean(val_loss_history)
logger.info("Val Loss : {} epoch : {}".format(epoch_val_loss,epoch))
if phase == 'val' and epoch_val_loss < best_loss:
logger.debug("New val loss : {} is better than old: {}, copying over the new weights".format(epoch_val_loss,best_loss))
best_loss = epoch_val_loss
best_model_wts = copy.deepcopy(self.network.state_dict())
######################################## <Logging> ###################################
if self.do_crossvalidation:
info = {'train_loss':epoch_train_loss,'val_loss':epoch_val_loss}
else:
info = {'train_loss':epoch_train_loss}
# scalar logs
for tag, value in info.items():
self.tf_logger.scalar_summary(tag, value, epoch+1)
# Log values and gradients of the parameters (histogram summary)
for tag, value in self.network.named_parameters():
tag = tag.replace('.', '/')
try:
self.tf_logger.histo_summary(tag, value.data.cpu().numpy(), epoch+1)
self.tf_logger.histo_summary(tag+'/grad', value.grad.data.cpu().numpy(), epoch+1)
except:
pass
######################################## </Logging> ###################################
# do stuff - like saving models
logger.info("EPOCH {} DONE".format(epoch+1))
# comment it out after debugging
if (epoch+1)==n_epochs or epoch%self.save_interval==0: # save the last model, and the ones in the specified interval
self.save_model(output_dir, epoch=(epoch+1), iteration=0, losses=losses)
## load the best weights
self.network.load_state_dict(best_model_wts)
# best epoch is 0
self.save_model(output_dir, epoch=0, iteration=0, losses=losses)
...@@ -3,6 +3,8 @@ from .MCCNNTrainer import MCCNNTrainer ...@@ -3,6 +3,8 @@ from .MCCNNTrainer import MCCNNTrainer
from .DCGANTrainer import DCGANTrainer from .DCGANTrainer import DCGANTrainer
from .ConditionalGANTrainer import ConditionalGANTrainer from .ConditionalGANTrainer import ConditionalGANTrainer
from .FASNetTrainer import FASNetTrainer from .FASNetTrainer import FASNetTrainer
from .GenericTrainer import GenericTrainer
from .GenericTrainer import GenericTrainer
# gets sphinx autodoc done right - don't remove it # gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')] __all__ = [_ for _ in dir() if not _.startswith('_')]
......
...@@ -71,6 +71,7 @@ setup( ...@@ -71,6 +71,7 @@ setup(
'console_scripts' : [ 'console_scripts' : [
'train_cnn.py = bob.learn.pytorch.scripts.train_cnn:main', 'train_cnn.py = bob.learn.pytorch.scripts.train_cnn:main',
'train_mccnn.py = bob.learn.pytorch.scripts.train_mccnn:main', 'train_mccnn.py = bob.learn.pytorch.scripts.train_mccnn:main',
'train_generic.py = bob.learn.pytorch.scripts.train_generic:main',
'train_fasnet.py = bob.learn.pytorch.scripts.train_fasnet:main', 'train_fasnet.py = bob.learn.pytorch.scripts.train_fasnet:main',
'train_dcgan.py = bob.learn.pytorch.scripts.train_dcgan:main', 'train_dcgan.py = bob.learn.pytorch.scripts.train_dcgan:main',
'train_conditionalgan.py = bob.learn.pytorch.scripts.train_conditionalgan:main', 'train_conditionalgan.py = bob.learn.pytorch.scripts.train_conditionalgan:main',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment