From 03c6aa09aef634e7ed04b9ae26374ecb9f84aeb9 Mon Sep 17 00:00:00 2001
From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch>
Date: Sun, 19 Mar 2023 04:36:31 +0100
Subject: [PATCH] added function to recognize lowest validation error. Closes
 #2

---
 src/ptbench/engine/trainer.py | 35 ++++++++++++++++++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 8cc72514..b9d89d93 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -64,6 +64,37 @@ def check_gpu(device):
             gpu_constants()
         ), f"Device set to '{device}', but nvidia-smi is not installed"
 
+def initialize_lowest_validation_loss(logfile_name, arguments):
+    """Initialize the lowest validation loss from the logfile if it exists and if the training
+    does not start from epoch 0, which means that a previous training session is resumed.
+    
+    Parameters
+    ----------
+
+    logfile_name : str
+        The logfile_name which is a join between the output_folder and trainlog.csv
+
+    arguments : dict
+        start and end epochs
+    """
+
+    if arguments["epoch"] != 0 and os.path.exists(logfile_name):
+        # Open the CSV file
+        with open(logfile_name, 'r') as file:
+            reader = csv.DictReader(file)
+            column_name = "validation_loss"
+            
+            if column_name not in reader.fieldnames:
+                return sys.float_info.max
+
+            # Get the values of the desired column as a list
+            values = [float(row[column_name]) for row in reader]
+
+        lowest_value = min(values)
+        logger.info(f"Found lowest validation error from previous session: {lowest_value}")
+        return lowest_value
+
+    return sys.float_info.max
 
 def save_model_summary(output_folder, model):
     """Save a little summary of the model in a txt file.
@@ -569,7 +600,9 @@ def run(
 
     # the lowest validation loss obtained so far - this value is updated only
     # if a validation set is available
-    lowest_validation_loss = sys.float_info.max
+    lowest_validation_loss = initialize_lowest_validation_loss(
+        logfile_name, arguments
+    )
 
     # set a specific validation criterion if the user has set one
     criterion_valid = criterion_valid or criterion
-- 
GitLab