Skip to content
Snippets Groups Projects
Commit 03c6aa09 authored by ogueler@idiap.ch's avatar ogueler@idiap.ch
Browse files

added function to recognize lowest validation error. Closes #2

parent c7068292
No related branches found
No related tags found
1 merge request!3Lowest validation loss
......@@ -64,6 +64,37 @@ def check_gpu(device):
gpu_constants()
), f"Device set to '{device}', but nvidia-smi is not installed"
def initialize_lowest_validation_loss(logfile_name, arguments):
"""Initialize the lowest validation loss from the logfile if it exists and if the training
does not start from epoch 0, which means that a previous training session is resumed.
Parameters
----------
logfile_name : str
The logfile_name which is a join between the output_folder and trainlog.csv
arguments : dict
start and end epochs
"""
if arguments["epoch"] != 0 and os.path.exists(logfile_name):
# Open the CSV file
with open(logfile_name, 'r') as file:
reader = csv.DictReader(file)
column_name = "validation_loss"
if column_name not in reader.fieldnames:
return sys.float_info.max
# Get the values of the desired column as a list
values = [float(row[column_name]) for row in reader]
lowest_value = min(values)
logger.info(f"Found lowest validation error from previous session: {lowest_value}")
return lowest_value
return sys.float_info.max
def save_model_summary(output_folder, model):
"""Save a little summary of the model in a txt file.
......@@ -569,7 +600,9 @@ def run(
# the lowest validation loss obtained so far - this value is updated only
# if a validation set is available
lowest_validation_loss = sys.float_info.max
lowest_validation_loss = initialize_lowest_validation_loss(
logfile_name, arguments
)
# set a specific validation criterion if the user has set one
criterion_valid = criterion_valid or criterion
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment