From 03c6aa09aef634e7ed04b9ae26374ecb9f84aeb9 Mon Sep 17 00:00:00 2001 From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch> Date: Sun, 19 Mar 2023 04:36:31 +0100 Subject: [PATCH] added function to recognize lowest validation error. Closes #2 --- src/ptbench/engine/trainer.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index 8cc72514..b9d89d93 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -64,6 +64,37 @@ def check_gpu(device): gpu_constants() ), f"Device set to '{device}', but nvidia-smi is not installed" +def initialize_lowest_validation_loss(logfile_name, arguments): + """Initialize the lowest validation loss from the logfile if it exists and if the training + does not start from epoch 0, which means that a previous training session is resumed. + + Parameters + ---------- + + logfile_name : str + The logfile_name which is a join between the output_folder and trainlog.csv + + arguments : dict + start and end epochs + """ + + if arguments["epoch"] != 0 and os.path.exists(logfile_name): + # Open the CSV file + with open(logfile_name, 'r') as file: + reader = csv.DictReader(file) + column_name = "validation_loss" + + if column_name not in reader.fieldnames: + return sys.float_info.max + + # Get the values of the desired column as a list + values = [float(row[column_name]) for row in reader] + + lowest_value = min(values) + logger.info(f"Found lowest validation error from previous session: {lowest_value}") + return lowest_value + + return sys.float_info.max def save_model_summary(output_folder, model): """Save a little summary of the model in a txt file. @@ -569,7 +600,9 @@ def run( # the lowest validation loss obtained so far - this value is updated only # if a validation set is available - lowest_validation_loss = sys.float_info.max + lowest_validation_loss = initialize_lowest_validation_loss( + logfile_name, arguments + ) # set a specific validation criterion if the user has set one criterion_valid = criterion_valid or criterion -- GitLab