Skip to content
Snippets Groups Projects
Commit 69f3c2bd authored by ogueler@idiap.ch's avatar ogueler@idiap.ch
Browse files

fixed formatting

parent 03c6aa09
No related branches found
No related tags found
1 merge request!3Lowest validation loss
...@@ -64,10 +64,12 @@ def check_gpu(device): ...@@ -64,10 +64,12 @@ def check_gpu(device):
gpu_constants() gpu_constants()
), f"Device set to '{device}', but nvidia-smi is not installed" ), f"Device set to '{device}', but nvidia-smi is not installed"
def initialize_lowest_validation_loss(logfile_name, arguments): def initialize_lowest_validation_loss(logfile_name, arguments):
"""Initialize the lowest validation loss from the logfile if it exists and if the training """Initialize the lowest validation loss from the logfile if it exists and
does not start from epoch 0, which means that a previous training session is resumed. if the training does not start from epoch 0, which means that a previous
training session is resumed.
Parameters Parameters
---------- ----------
...@@ -80,10 +82,10 @@ def initialize_lowest_validation_loss(logfile_name, arguments): ...@@ -80,10 +82,10 @@ def initialize_lowest_validation_loss(logfile_name, arguments):
if arguments["epoch"] != 0 and os.path.exists(logfile_name): if arguments["epoch"] != 0 and os.path.exists(logfile_name):
# Open the CSV file # Open the CSV file
with open(logfile_name, 'r') as file: with open(logfile_name) as file:
reader = csv.DictReader(file) reader = csv.DictReader(file)
column_name = "validation_loss" column_name = "validation_loss"
if column_name not in reader.fieldnames: if column_name not in reader.fieldnames:
return sys.float_info.max return sys.float_info.max
...@@ -91,11 +93,14 @@ def initialize_lowest_validation_loss(logfile_name, arguments): ...@@ -91,11 +93,14 @@ def initialize_lowest_validation_loss(logfile_name, arguments):
values = [float(row[column_name]) for row in reader] values = [float(row[column_name]) for row in reader]
lowest_value = min(values) lowest_value = min(values)
logger.info(f"Found lowest validation error from previous session: {lowest_value}") logger.info(
f"Found lowest validation error from previous session: {lowest_value}"
)
return lowest_value return lowest_value
return sys.float_info.max return sys.float_info.max
def save_model_summary(output_folder, model): def save_model_summary(output_folder, model):
"""Save a little summary of the model in a txt file. """Save a little summary of the model in a txt file.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment