Skip to content
Snippets Groups Projects

autoencoders pretraining using RGB faces

Merged Olegs NIKISINS requested to merge autoencoder_pretrain into master
Files
3
+ 496
0
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
The following steps are performed in this script:
1. The command line arguments are first parsed.
2. Folder to save the results to is created.
3. Configuration file specifying the Network and learning parameters is
loaded.
4. A generic data loader compatible with Bob High Level Database
Interfaces, namely DataFolder, is initialized.
5. The Network is initialized, can also be initialized with pre-trained
model.
6. The training is performed. Verbosity flag can be used to see and save
training related outputs. See ``process_verbosity`` function for
more details.
7. The model is saved after each epoch.
@author: Olegs Nikisins
"""
#==============================================================================
# Import here:
import argparse
import importlib
import os
from bob.pad.face.database.pytorch import DataFolder
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.utils import save_image
import logging
logger = logging.getLogger("bob.learn.pytorch")
import numpy as np
import time
#==============================================================================
def parse_arguments(cmd_params=None):
    """
    Parse command line arguments.

    **Parameters:**

    ``cmd_params``: []
        An optional list of command line arguments. When ``None`` (default)
        the arguments are read from ``sys.argv``.

    **Returns:**

    ``data_folder``: py:class:`string`
        A directory containing the training data.

    ``save_folder``: py:class:`string`
        A directory to save the results of training to.

    ``relative_mod_name``: py:class:`string`
        Relative (dot-prefixed) name of the module to import configurations
        from, derived from the ``--config-file`` path.

    ``config_group``: py:class:`string`
        Group/package name containing the configuration file.

    ``pretrained_model_path``: py:class:`string`
        Absolute name of the file, containing pre-trained Network
        model, to be used for Network initialization before training.
        Empty string when not given.

    ``cross_validate``: bool
        Cross-validate the current model on the dev set of the database used
        for training.

    ``use_gpu``: bool
        Use the GPU for training, if one is available.

    ``verbosity``: py:class:`int`
        Verbosity level (number of ``-v`` flags given).
    """
    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument("data_folder", type=str,
                        help="A directory containing the training data.")

    parser.add_argument("save_folder", type=str,
                        help="A directory to save the results of training to.")

    parser.add_argument("-c", "--config-file", type=str,
                        help="Relative name of the config file defining "
                             "the network, training data, and training parameters.",
                        default="autoencoder/autoencoder_config.py")

    parser.add_argument("-cg", "--config-group", type=str,
                        help="Name of the group, where config file is stored.",
                        default="bob.pad.face.config.pytorch")

    parser.add_argument("-p", "--pretrained-model-path", type=str,
                        help="Absolute name of the file, containing pre-trained Network "
                             "model, to de used for Network initialization before training.",
                        default="")

    parser.add_argument("-cv", "--cross-validate", action="store_true",
                        help="Cross validate the current model on the dev set of the database.",
                        default=False)

    parser.add_argument("-gpu", "--use-gpu", action="store_true",
                        help="Use the GPU for model training, if GPU is available in your system.",
                        default=False)

    parser.add_argument("-v", "--verbosity", action="count", default=0,
                        help="Increase output verbosity. For -v loss is printed. "
                             "For -vv output images are saved.")

    # ``parse_args(None)`` falls back to ``sys.argv[1:]``, so no branching
    # on ``cmd_params`` is required.
    args = parser.parse_args(cmd_params)

    # Turn the config file path into a relative module name. The forward
    # slash is always accepted (the default value uses it), in addition to
    # the separators of the current platform.
    module_path = os.path.splitext(args.config_file)[0]
    for separator in (os.path.sep, os.path.altsep, "/"):
        if separator:
            module_path = module_path.replace(separator, ".")
    relative_mod_name = "." + module_path

    return (args.data_folder,
            args.save_folder,
            relative_mod_name,
            args.config_group,
            args.pretrained_model_path,
            args.cross_validate,
            args.use_gpu,
            args.verbosity)
#==============================================================================
def to_img(batch):
    """
    Normalize the images in the batch to [0, 1] range for plotting.

    The minimum of the whole batch is mapped to 0 and the maximum to 1.
    A batch in which all values are identical has no range to stretch;
    it is returned as all zeros instead of dividing by zero.

    **Parameters:**

    ``batch`` : Tensor
        A tensor containing a batch of images.
        The size of the tensor: (num_imgs x num_color_channels x H x W).

    **Returns:**

    ``batch`` : Tensor
        A tensor containing a normalized batch of images.
        The size of the tensor: (num_imgs x num_color_channels x H x W).
    """
    shifted = batch - batch.min()
    value_range = shifted.max()

    # Guard against 0 / 0 = NaN for constant batches; ``clamp`` would not
    # remove the NaNs and the saved image would be garbage.
    if value_range > 0:
        shifted = shifted / value_range

    return shifted.clamp(0, 1)
#==============================================================================
def process_verbosity(verbosity,
                      epoch,
                      num_epochs,
                      loss_value,
                      epoch_step,
                      batch_tensor,
                      save_folder):
    """
    Report training progress according to the verbosity level.

    1. For a verbosity level above 0 the loss is logged.
    2. For a verbosity level above 1 a normalized image of the network
       output is additionally saved whenever ``epoch`` is a multiple of
       ``epoch_step``.

    **Parameters:**

    ``verbosity``: py:class:`int`
        Verbosity level.

    ``epoch``: py:class:`int`
        Current epoch number.

    ``num_epochs``: py:class:`int`
        Total number of epochs.

    ``loss_value``: py:class:`float`
        Loss value for the current epoch.

    ``epoch_step``: py:class:`int`
        Save the images after each ``epoch_step`` epochs.

    ``batch_tensor`` : Tensor
        A tensor containing a batch of NN output images.
        The size of the tensor: (num_imgs x num_color_channels x H x W).

    ``save_folder``: py:class:`str`
        Folder to save images to.
    """
    report_loss = verbosity > 0
    if report_loss:
        message = 'epoch [{}/{}], loss:{:.6f}'.format(epoch, num_epochs, loss_value)
        logger.info(message)

    # The modulo is only evaluated when image saving is requested at all.
    store_image = verbosity > 1 and epoch % epoch_step == 0
    if store_image:
        normalized = to_img(batch_tensor)
        image_path = os.path.join(save_folder, 'image_{}.png'.format(epoch))
        save_image(normalized, image_path)
#==============================================================================
def _select_device(use_gpu, verbosity):
    """Return the torch device to train on, honouring the ``use_gpu`` flag."""
    if not use_gpu:
        return torch.device("cpu")

    # fall back to the CPU when the user asked for a GPU but none is present:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if verbosity > 0:
        logger.info("The number of GPUs available in the system and used for training: {}".format(
            torch.cuda.device_count()))
    return device


def _make_dataset(config_module, dataset_kwargs):
    """Instantiate a dataset from kwargs, using the class named in the config if any."""
    if hasattr(config_module, "dataset_classdict"):
        dataset_class = config_module.dataset_classdict[config_module.dataset_class_name]
    else:
        dataset_class = DataFolder
    return dataset_class(**dataset_kwargs)


def _build_datasets(config_module, data_folder, cross_validate):
    """Return ``(dataset, dataset_dev)``; ``dataset_dev`` is None without cross-validation."""

    def base_kwargs():
        # Work on a copy, so the dictionary stored in the config module is
        # never mutated, and so the kwargs are available for the dev set
        # even when the config supplies a ready-made training ``dataset``.
        kwargs = dict(config_module.kwargs)
        kwargs["data_folder"] = data_folder  # data folder comes from the command line
        return kwargs

    if hasattr(config_module, "dataset"):
        dataset = config_module.dataset
    else:
        dataset = _make_dataset(config_module, base_kwargs())

    dataset_dev = None
    if cross_validate:
        if hasattr(config_module, "dataset_dev"):
            dataset_dev = config_module.dataset_dev
        else:
            dev_kwargs = base_kwargs()
            dev_kwargs["groups"] = ["dev"]  # select the data of the "dev" set
            dataset_dev = _make_dataset(config_module, dev_kwargs)

    return dataset, dataset_dev


def _build_network(config_module):
    """Instantiate ``config_module.Network`` with ``network_kwargs`` when the config defines them."""
    network_kwargs = getattr(config_module, "network_kwargs", {})
    return config_module.Network(**network_kwargs)


def _compute_loss(config_module, output, img, target):
    """Compute the loss with the config's ``loss_function`` if defined, else with ``loss_type``."""
    if hasattr(config_module, "loss_function"):
        return config_module.loss_function(output, img, target)
    if isinstance(output, tuple):  # the network returns two tensors to be compared
        return config_module.loss_type(output[0], output[1])
    return config_module.loss_type(output, img)


def _evaluate_on_dev_set(config_module, model_dev, dataloader_dev, device, preprocess):
    """Return the mean loss of ``model_dev`` over all batches of ``dataloader_dev``."""
    batch_losses = []
    with torch.no_grad():  # evaluation only: no autograd graph is needed
        for img_dev, target_dev in dataloader_dev:
            # Apply the same input preprocessing as in training.
            # NOTE(review): the original skipped this step for the dev set;
            # assumed to be an oversight -- confirm against the configs in use.
            if preprocess is not None:
                img_dev = preprocess(img_dev)
            img_dev = img_dev.to(device)
            target_dev = target_dev.to(device)
            output_dev = model_dev(img_dev)
            loss_dev = _compute_loss(config_module, output_dev, img_dev, target_dev)
            batch_losses.append(loss_dev.item())
    return np.mean(batch_losses)


def main(cmd_params=None):
    """
    Train the network described by a configuration module.

    The following steps are performed in this function:

    1. The command line arguments are first parsed.
    2. Folder to save the results to is created.
    3. Configuration file specifying the Network and learning parameters is
       loaded.
    4. The training dataset (and, with cross-validation enabled, the dev
       dataset) is taken from the configuration or built from its kwargs,
       using ``DataFolder`` unless the configuration names another class.
    5. The Network is initialized, can also be initialized with pre-trained
       model.
    6. The training is performed. Verbosity flag can be used to see and save
       training related outputs. See ``process_verbosity`` function for
       more details. With cross-validation enabled the mean loss on the dev
       set is computed after each epoch.
    7. The model is saved after each epoch.

    **Parameters:**

    ``cmd_params``: []
        An optional list of command line arguments, forwarded to
        ``parse_arguments``. Default: None.
    """
    epoch_step = 1  # save images and trained model after each ``epoch_step`` epoch

    (data_folder, save_folder, relative_mod_name, config_group,
     pretrained_model_path, cross_validate, use_gpu, verbosity) = \
        parse_arguments(cmd_params=cmd_params)

    # ``makedirs`` also creates missing parent folders:
    if not os.path.isdir(save_folder):
        os.makedirs(save_folder)

    config_module = importlib.import_module(relative_mod_name, config_group)

    device = _select_device(use_gpu, verbosity)

    # =========================================================================
    # Handle the "dataset" initialization:
    dataset, dataset_dev = _build_datasets(config_module, data_folder, cross_validate)

    if verbosity > 0:
        logger.info("The number of training samples: {}".format(len(dataset)))
        if cross_validate:
            logger.info("The number of cross-validation samples: {}".format(len(dataset_dev)))

    # =========================================================================
    # Handle the "dataloader" initialization:

    # Call the dataset ``__getitem__`` once, to **possibly** compute
    # normalization parameters, before any worker process is involved.
    _ = dataset[0]

    # The number of workers is passed to the constructor rather than being
    # assigned afterwards, so the loader is set up consistently for it.
    dataloader = DataLoader(dataset,
                            batch_size=config_module.BATCH_SIZE,
                            shuffle=True,
                            num_workers=getattr(config_module, "NUM_WORKERS", 0))

    if cross_validate:
        dataloader_dev = DataLoader(dataset_dev,
                                    batch_size=config_module.BATCH_SIZE,
                                    shuffle=False)  # shuffling is not needed in cross-validation

    if verbosity > 0:
        logger.info("The number of workers for the DataLoader is: {}".format(dataloader.num_workers))

    # =========================================================================
    # Handle the initialization of the networks to be used for training and cross-validation:
    model = _build_network(config_module)
    if cross_validate:
        model_dev = _build_network(config_module)  # the network to be used for cross-validation
        model_dev.train(False)  # Model is used for evaluation only

    # =========================================================================
    # Load pre-trained model if given:
    if pretrained_model_path:
        if verbosity > 0:
            logger.info("Initializing the Network with pre-trained model from file: " + pretrained_model_path)
        # NOTE(security): ``torch.load`` unpickles the file; only load models
        # from trusted sources.
        model_state = torch.load(pretrained_model_path,
                                 map_location=lambda storage, loc: storage)
        model.load_state_dict(model_state)

    # =========================================================================
    # Set the optimizer:
    if hasattr(config_module, "param_idx_that_requires_grad"):  # select the parameters to be updated
        for idx, param in enumerate(model.parameters()):
            if idx not in config_module.param_idx_that_requires_grad:
                logger.info("Parameter {} in the network is not updated".format(idx))
                param.requires_grad = False

    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable_params,
                                 lr=config_module.LEARNING_RATE,
                                 weight_decay=1e-5)

    # =========================================================================
    # handle the GPU usage in the training:
    model.to(device)
    if cross_validate:
        model_dev.to(device)

    # =========================================================================
    # Training and cross-validation for ``NUM_EPOCHS`` epochs:
    num_epochs = config_module.NUM_EPOCHS
    preprocess = getattr(config_module, "data_preproc_function", None)

    # Report on the penultimate batch, because the last batch can be smaller
    # than BATCH_SIZE. With a single batch there is no penultimate one, so
    # that only batch is reported instead of reporting nothing.
    report_batch = max(len(dataloader) - 1, 1)

    for epoch in range(num_epochs):

        start = time.time()

        for batch_num, (img, target) in enumerate(dataloader, start=1):

            # if function to preprocess the network data is defined, do preprocessing:
            if preprocess is not None:
                img = preprocess(img)

            img = img.to(device)
            target = target.to(device)

            # ===================forward========================================
            output = model(img)
            loss = _compute_loss(config_module, output, img, target)

            # ===================backward=======================================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()  # does the update

            if batch_num == report_batch:
                reconstruction = output[0] if isinstance(output, tuple) else output
                process_verbosity(verbosity=verbosity,
                                  epoch=epoch + 1,
                                  num_epochs=num_epochs,
                                  loss_value=loss.item(),
                                  epoch_step=epoch_step,
                                  batch_tensor=reconstruction.detach().cpu(),
                                  save_folder=save_folder)

        end = time.time()

        if verbosity > 0:
            logger.info('Time taken by current epoch, excluding cross-validation: {:.6f} (seconds)'.format(end - start))

        # =====================================================================
        # handle the cross-validation loss:
        if cross_validate:
            # initialize the dev model with current state of the training network:
            model_dev.load_state_dict(model.state_dict())

            mean_loss_dev_set = _evaluate_on_dev_set(config_module, model_dev,
                                                     dataloader_dev, device, preprocess)

            if verbosity > 0:
                # epoch is reported 1-based, consistently with the training loss log
                logger.info('epoch [{}/{}], dev loss:{:.6f}'.format(epoch + 1, num_epochs, mean_loss_dev_set))

        # =====================================================================
        if (epoch + 1) % epoch_step == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_folder, 'model_{}.pth'.format(epoch + 1)))
Loading