#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ The following steps are performed in this script: 1. The command line arguments are first parsed. 2. Folder to save the results to is created. 3. Configuration file specifying the Network and learning parameters is loaded. 4. A generic data loader compatible with Bob High Level Database Interfaces, namely DataFolder, is initialized. 5. The Network is initialized, can also be initialized with pre-trained model. 6. The training is performed. Verbosity flag can be used to see and save training related outputs. See ``process_verbosity`` function for more details. 7. The model is saved after each 1 epochs. @author: Olegs Nikisins """ #============================================================================== # Import here: import argparse import importlib import os from bob.learn.pytorch.datasets import DataFolder import torch from torch.utils.data import DataLoader from torch.autograd import Variable from torchvision.utils import save_image import logging logger = logging.getLogger("bob.learn.pytorch") import numpy as np import time #============================================================================== def parse_arguments(cmd_params=None): """ Parse command line arguments. **Parameters:** ``cmd_params``: [] An optional list of command line arguments. Default: None. **Returns:** ``data_folder``: py:class:`string` A directory containing the training data. ``save_folder``: py:class:`string` A directory to save the results of training to. ``relative_mod_name``: py:class:`string` Relative name of the module to import configurations from. ``config_group``: py:class:`string` Group/package name containing the configuration file. ``pretrained_model_path``: py:class:`string` Absolute name of the file, containing pre-trained Network model, to de used for Network initialization before training. ``cross_validate``: bool Cross-validate the current model on the dev set of the database used for training. Cross validation is done after each training epoch, using entire development set of the database. ``verbosity``: py:class:`int` Verbosity level. """ parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("data_folder", type=str, help="A directory containing the training data.") parser.add_argument("save_folder", type=str, help="A directory to save the results of training to.") parser.add_argument("-c", "--config-file", type=str, help="Relative name of the config file defining " "the network, training data, and training parameters.", default = "autoencoder/net1_celeba.py") parser.add_argument("-cg", "--config-group", type=str, help="Name of the group, where config file is stored.", default = "bob.learn.pytorch.config") parser.add_argument("-p", "--pretrained-model-path", type=str, help="Absolute name of the file, containing pre-trained Network " "model, to de used for Network initialization before training.", default = "") parser.add_argument("-cv", "--cross-validate", action="store_true", help="Cross validate the current model on the dev set of the database.", default = False) parser.add_argument("-gpu", "--use-gpu", action="store_true", help="Use the GPU for model training, if GPU is available in your system.", default = False) parser.add_argument("-v", "--verbosity", action="count", default=0, help="Increase output verbosity. For -v loss is printed. For -vv output images are saved.") if cmd_params is not None: args = parser.parse_args(cmd_params) else: args = parser.parse_args() data_folder = args.data_folder save_folder = args.save_folder config_file = args.config_file config_group = args.config_group pretrained_model_path = args.pretrained_model_path cross_validate = args.cross_validate use_gpu = args.use_gpu verbosity = args.verbosity relative_mod_name = '.' + os.path.splitext(config_file)[0].replace(os.path.sep, '.') return data_folder, save_folder, relative_mod_name, config_group, pretrained_model_path, cross_validate, use_gpu, verbosity #============================================================================== def to_img(batch): """ Normalize the images in the batch to [0, 1] range for plotting. **Parameters:** ``batch`` : Tensor A tensor containing a batch of images. The size of the tensor: (num_imgs x num_color_channels x H x W). **Returns:** ``batch`` : Tensor A tensor containing a normalized batch of images. The size of the tensor: (num_imgs x num_color_channels x H x W). """ batch = (batch - batch.min()) batch = batch / batch.max() batch = batch.clamp(0, 1) return batch #============================================================================== def process_verbosity(verbosity, epoch, num_epochs, loss_value, epoch_step, batch_tensor, save_folder): """ Report results based on the verbose level. 1. If verbosity level is 1: loss is printed for each epoch. 2. If verbosity levle is greater than 1: both loss is printed and a reconstructed image is saved efter each ``epoch_step`` epochs. **Parameters:** ``verbosity``: py:class:`int` Verbosity level. ``epoch``: py:class:`int` Current epoch number. ``num_epochs``: py:class:`int` Total number of epochs. ``loss_value``: py:class:`float` Loss value for the current epoch. ``epoch_step``: py:class:`int` Plot the images after each ``epoch_step`` epochs. ``batch_tensor`` : Tensor A tensor containing a batch of NN output images. The size of the tensor: (num_imgs x num_color_channels x H x W). ``save_folder``: py:class:`str` Folder to save images to. """ if verbosity > 0: logger.info ('epoch [{}/{}], loss:{:.6f}'.format(epoch, num_epochs, loss_value)) if verbosity > 1: if epoch % epoch_step == 0: pic = to_img(batch_tensor) save_image( pic, os.path.join(save_folder, 'image_{}.png'.format(epoch)) ) #============================================================================== def main(cmd_params=None): """ The following steps are performed in this function: 1. The command line arguments are first parsed. 2. Folder to save the results to is created. 3. Configuration file specifying the Network and learning parameters is loaded. 4. A generic data loader compatible with Bob High Level Database Interfaces, namely DataFolder, is initialized. 5. The Network is initialized, can also be initialized with pre-trained model. 6. The training is performed. Verbosity flag can be used to see and save training related outputs. See ``process_verbosity`` function for more details. 7. The model is saved after each 1 epochs. """ epoch_step = 1 # save images and trained model after each ``epoch_step`` epoch data_folder, save_folder, relative_mod_name, config_group, pretrained_model_path, cross_validate, use_gpu, verbosity = \ parse_arguments(cmd_params = cmd_params) if not os.path.exists(save_folder): os.mkdir(save_folder) config_module = importlib.import_module(relative_mod_name, config_group) # ========================================================================= # handle the GPU usage in the training: if use_gpu: # if GPU usage is enabled by the user # check if GPU is available in the system: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if verbosity > 0: logger.info ("The number of GPUs available in the system and used for training: {}".format( torch.cuda.device_count())) else: device = torch.device("cpu") # ========================================================================= # Handle the "dataset" initialization: if "dataset" in dir(config_module): # if dataset is initialized in config_module use it dataset = config_module.dataset else: # otherwise initialize the dataset dataset_kwargs = config_module.kwargs dataset_kwargs["data_folder"] = data_folder # set the datafolder from command line arguments if "dataset_classdict" in dir(config_module): # if dataset should be initialized from non "DataFolder" class dataset = config_module.dataset_classdict[config_module.dataset_class_name](**dataset_kwargs) else: # else initialize the DataFolder from kwargs dataset = DataFolder(**dataset_kwargs) if cross_validate: # if cross validation is enabled: if "dataset_dev" in dir(config_module): # if dataset_dev is initialized in config_module use it dataset_dev = config_module.dataset_dev else: # otherwise initialize the dataset_dev dataset_kwargs_dev = dataset_kwargs.copy() # copy the kwargs for dataset initialization dataset_kwargs_dev['groups'] = ['dev'] # select the data for the "dev" set if "dataset_classdict" in dir(config_module): # if dataset should be initialized from non "DataFolder" class dataset_dev = config_module.dataset_classdict[config_module.dataset_class_name](**dataset_kwargs_dev) else: # else initialize the DataFolder from kwargs dataset_dev = DataFolder(**dataset_kwargs_dev) # ========================================================================= # Handle the "dataloader" initialization: if verbosity > 0: logger.info ( "The number of training samples: {}".format( dataset.__len__() ) ) if cross_validate: # if cross validation is enabled: logger.info ( "The number of cross-validation samples: {}".format( dataset_dev.__len__() ) ) dataloader = DataLoader(dataset, batch_size = config_module.BATCH_SIZE, shuffle = True) if cross_validate: # if cross validation is enabled: dataloader_dev = DataLoader(dataset_dev, batch_size = config_module.BATCH_SIZE, shuffle = False) # shuffling is not needed in cross-validation UNUSED = dataset.__getitem__(0) # call a dataset __getitem__ once, to **possibly** compute normalization parameters, after that num_workers can be set for dataloader if "NUM_WORKERS" in dir(config_module) and dataloader.num_workers == 0: # set the number of workers for the DataLoader dataloader.num_workers = config_module.NUM_WORKERS if verbosity > 0: logger.info ( "The number of workers for the DataLoader is: {}".format(dataloader.num_workers) ) # ========================================================================= # Handle the initialization of the networks to be used for training and cross-validation: if "network_kwargs" in dir(config_module): network_kwargs = config_module.network_kwargs model = config_module.Network(**network_kwargs) if cross_validate: # if cross validation is enabled: model_dev = config_module.Network(**network_kwargs) # the network to be used for cross-validation model_dev.train(False) # Model is used for evaluation only else: model = config_module.Network() if cross_validate: # if cross validation is enabled: model_dev = config_module.Network() # the network to be used for cross-validation model_dev.train(False) # Model is used for evaluation only # ========================================================================= # Load pre-trained model if given: if pretrained_model_path: # initialize with pre-trained model if given if verbosity > 0: logger.info ("Initializing the Network with pre-trained model from file: " + pretrained_model_path) model_state=torch.load(pretrained_model_path,map_location=lambda storage,loc:storage) # Initialize the state of the model: model.load_state_dict(model_state) loss_type = config_module.loss_type # ========================================================================= # Set the optimizer: if "param_idx_that_requires_grad" in dir(config_module): # select the parameters to be updated for idx, param in enumerate(model.parameters()): if idx not in config_module.param_idx_that_requires_grad: logger.info ("Parameter {} in the network is not updated".format(idx)) param.requires_grad = False optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config_module.LEARNING_RATE, weight_decay=1e-5) # ========================================================================= # handle the GPU usage in the training: model.to(device) if cross_validate: model_dev.to(device) # ========================================================================= # Training and cross-validation for ``NUM_EPOCHS`` epochs: mean_losses_dev_set = [] # list to accumulate mean losses computed on dev set for epoch in range(config_module.NUM_EPOCHS): batch_num = 0 start = time.time() for data in dataloader: batch_num = batch_num + 1 img, target = data # if function to preprocess the network data is defined, do preprocessing: if "data_preproc_function" in dir(config_module): # img = config_module.data_preproc_function(img) img = Variable(img) img = img.to(device) target = target.to(device) #===================forward======================================== output = model(img) if "loss_function" in dir(config_module): loss = config_module.loss_function(output, img, target) else: if isinstance(output, tuple): # if network returns 2 parameters loss = loss_type(output[0], output[1]) else: loss = loss_type(output, img) #===================backward======================================= optimizer.zero_grad() loss.backward() optimizer.step() # does the update if batch_num == len(dataloader) - 1: # process verbosity using penultimate batch, because the # last batch can be smaller than BATCH_SIZE. process_verbosity(verbosity = verbosity, epoch = epoch+1, num_epochs = config_module.NUM_EPOCHS, loss_value = loss.item(), epoch_step = epoch_step, batch_tensor = output.data.cpu().data if not isinstance(output, tuple) else output[0].data.cpu().data, save_folder = save_folder) end = time.time() if verbosity > 0: logger.info ('Time taken by current epoch, excluding cross-validation: {:.6f} (seconds)'.format(end-start)) # ===================================================================== # handle the cross-validation loss: if cross_validate: # if cross validation is enabled: # initialize the dev model with current state of the training network: model_dev.load_state_dict(model.state_dict()) losses_dev_set = [] # list to accumulate batch losses computed on dev set for data_dev in dataloader_dev: # get a training data for dev set img_dev, target_dev = data_dev img_dev = Variable(img_dev) img_dev = img_dev.to(device) target_dev = target_dev.to(device) output_dev = model_dev(img_dev) if "loss_function" in dir(config_module): loss_dev = config_module.loss_function(output_dev, img_dev, target_dev) else: if isinstance(output_dev, tuple): # if network returns 2 parameters loss_dev = loss_type(output_dev[0], output_dev[1]) else: loss_dev = loss_type(output_dev, img_dev) # print (loss_dev.item()) losses_dev_set.append(loss_dev.item()) mean_loss_dev_set = np.mean(losses_dev_set) # mean loss across all batches of dev set for current epoch mean_losses_dev_set.append(mean_loss_dev_set) if verbosity > 0: logger.info ('epoch [{}/{}], dev loss:{:.6f}'.format(epoch, config_module.NUM_EPOCHS, mean_loss_dev_set)) # ===================================================================== if (epoch+1) % epoch_step == 0: torch.save(model.state_dict(), os.path.join(save_folder, 'model_{}.pth'.format(epoch+1)))