Improve trainer function

Split the trainer run function into separate steps so that individual parts of the code are easier to reuse for other types of training (for example, multi-task learning), where we will need to work with multiple losses. The function is now split into:

  • check_gpu(device) : Check the device type and the availability of a GPU.

  • save_model_summary(output_folder, model) : Save a summary of the model in a text file.

  • static_information_to_csv(static_logfile_name, device, n) : Save the static information in a csv file.

  • check_exist_logfile(logfile_name, arguments) : Check whether the logfile (trainlog.csv) exists. If it exists and the epoch number is still 0, the logfile is replaced.

  • create_logfile_fields(valid_loader, device) : Create the fields that will appear in the logfile.

  • train_sample_process(samples, model, optimizer, losses, device, criterion) : Process the training inputs (images, ground truth, masks), apply backpropagation, and update the training losses.

  • valid_sample_process(samples, model, valid_losses, device, criterion) : Process the validation inputs (images, ground truth, masks) and update the validation losses.

  • checkpointer_process(checkpointer, checkpoint_period, valid_losses, lowest_validation_loss, arguments, epoch, max_epoch) : Run the checkpointer, save the final model, and keep track of the best model.

  • write_log_info(epoch, current_time, eta_seconds, losses, valid_losses, optimizer, logwriter, logfile, device) : Write the log information to trainlog.csv.
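
To illustrate the logfile handling described in the list, here is a minimal sketch using only the Python standard library. It is not the code from this merge request: the `"epoch"` key in `arguments`, the interpretation of "replaced" as deleting the stale file, and the helper `append_log_row` (a simplified stand-in for `write_log_info`) are all assumptions made for the example.

```python
import csv
import os


def check_exist_logfile(logfile_name, arguments):
    """Drop a stale logfile when training starts from scratch.

    Assumption: `arguments` is a dict holding the current epoch under
    the key "epoch"; the real structure may differ.
    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        # Training restarts at epoch 0, so the old log is discarded.
        os.remove(logfile_name)


def append_log_row(logfile_name, fields, row):
    """Hypothetical, simplified stand-in for write_log_info.

    Appends one row to the CSV logfile and writes the header first
    if the file does not exist yet.
    """
    write_header = not os.path.exists(logfile_name)
    with open(logfile_name, "a", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=fields)
        if write_header:
            logwriter.writeheader()
        logwriter.writerow(row)
```

With this split, a different training loop (for example, one with several losses) could call `check_exist_logfile` once before training and `append_log_row` once per epoch with its own set of fields.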

Edited by Driss KHALIL
