From 69f3c2bd3e706263afb80c2ad2e9ed99b7078b5f Mon Sep 17 00:00:00 2001 From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch> Date: Sun, 19 Mar 2023 04:37:57 +0100 Subject: [PATCH] fixed formatting --- src/ptbench/engine/trainer.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index b9d89d93..38147f84 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -64,10 +64,12 @@ def check_gpu(device): gpu_constants() ), f"Device set to '{device}', but nvidia-smi is not installed" + def initialize_lowest_validation_loss(logfile_name, arguments): - """Initialize the lowest validation loss from the logfile if it exists and if the training - does not start from epoch 0, which means that a previous training session is resumed. - + """Initialize the lowest validation loss from the logfile if it exists and + if the training does not start from epoch 0, which means that a previous + training session is resumed. + Parameters ---------- @@ -80,10 +82,10 @@ def initialize_lowest_validation_loss(logfile_name, arguments): if arguments["epoch"] != 0 and os.path.exists(logfile_name): # Open the CSV file - with open(logfile_name, 'r') as file: + with open(logfile_name) as file: reader = csv.DictReader(file) column_name = "validation_loss" - + if column_name not in reader.fieldnames: return sys.float_info.max @@ -91,11 +93,14 @@ def initialize_lowest_validation_loss(logfile_name, arguments): values = [float(row[column_name]) for row in reader] lowest_value = min(values) - logger.info(f"Found lowest validation error from previous session: {lowest_value}") + logger.info( + f"Found lowest validation error from previous session: {lowest_value}" + ) return lowest_value return sys.float_info.max + def save_model_summary(output_folder, model): """Save a little summary of the model in a txt file. -- GitLab