Improve trainer function

Driss KHALIL requested to merge ImproveTrainerFunction into master

Split the trainer's run function into separate steps so that individual parts of the code can be reused for other types of training (for example, multi-task learning, where we need to work with multiple losses). The function is now split into:

  • check_gpu(device) : Checks the device type and whether a GPU is available.

  • save_model_summary(output_folder, model) : Saves a summary of the model to a text file.

  • static_information_to_csv(static_logfile_name, device, n) : Saves the static information to a CSV file.

  • check_exist_logfile(logfile_name, arguments) : Checks whether the logfile (trainlog.csv) exists. If it exists and the epoch number is still 0, the logfile is replaced.

  • create_logfile_fields(valid_loader, device) : Creates the fields that will appear in the logfile.

  • train_sample_process(samples, model, optimizer, losses, device, criterion) : Processes the training inputs (images, ground truth, masks), applies backpropagation, and updates the training losses.

  • valid_sample_process(samples, model, valid_losses, device, criterion) : Processes the validation inputs (images, ground truth, masks) and updates the validation losses.

  • checkpointer_process(checkpointer, checkpoint_period, valid_losses, lowest_validation_loss, arguments, epoch, max_epoch) : Runs the checkpointer, saves the final model, and keeps track of the best model.

  • write_log_info(epoch, current_time, eta_seconds, losses, valid_losses, optimizer, logwriter, logfile, device) : Writes the log information to trainlog.csv.
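
To illustrate the logfile check and the best-model tracking described above, here is a minimal sketch. It is not the code from this merge request: it assumes that `arguments` is a dict whose "epoch" entry holds the number of completed epochs, it interprets "replaced" as deleting the stale file so a fresh one can be written, and `is_new_best` is a hypothetical helper standing in for the comparison that checkpointer_process would need to perform.

```python
import os


def check_exist_logfile(logfile_name, arguments):
    """Discard a stale logfile when training starts from scratch.

    Assumption: arguments["epoch"] is the number of epochs already
    completed (0 for a fresh run).
    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        # Remove the old file so that a fresh log is written instead.
        os.remove(logfile_name)


def is_new_best(valid_loss, lowest_validation_loss):
    """Hypothetical helper: return (improved, updated_lowest_loss)."""
    if valid_loss < lowest_validation_loss:
        return True, valid_loss
    return False, lowest_validation_loss
```

When resuming (a non-zero epoch count), the existing logfile is left untouched so that new rows can be appended to it.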

Edited by Driss KHALIL