Skip to content
Snippets Groups Projects
Commit 69f3c2bd authored by ogueler@idiap.ch's avatar ogueler@idiap.ch
Browse files

fixed formatting

parent 03c6aa09
No related branches found
No related tags found
1 merge request!3Lowest validation loss
......@@ -64,10 +64,12 @@ def check_gpu(device):
gpu_constants()
), f"Device set to '{device}', but nvidia-smi is not installed"
def initialize_lowest_validation_loss(logfile_name, arguments):
"""Initialize the lowest validation loss from the logfile if it exists and if the training
does not start from epoch 0, which means that a previous training session is resumed.
"""Initialize the lowest validation loss from the logfile if it exists and
if the training does not start from epoch 0, which means that a previous
training session is resumed.
Parameters
----------
......@@ -80,10 +82,10 @@ def initialize_lowest_validation_loss(logfile_name, arguments):
if arguments["epoch"] != 0 and os.path.exists(logfile_name):
# Open the CSV file
with open(logfile_name, 'r') as file:
with open(logfile_name) as file:
reader = csv.DictReader(file)
column_name = "validation_loss"
if column_name not in reader.fieldnames:
return sys.float_info.max
......@@ -91,11 +93,14 @@ def initialize_lowest_validation_loss(logfile_name, arguments):
values = [float(row[column_name]) for row in reader]
lowest_value = min(values)
logger.info(f"Found lowest validation error from previous session: {lowest_value}")
logger.info(
f"Found lowest validation error from previous session: {lowest_value}"
)
return lowest_value
return sys.float_info.max
def save_model_summary(output_folder, model):
"""Save a little summary of the model in a txt file.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment