diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index b9d89d93dd241447805a885160d46c658f3bc9c7..38147f843c6ee0c4e93b241f813081f7f6c51d64 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -64,10 +64,12 @@ def check_gpu(device): gpu_constants() ), f"Device set to '{device}', but nvidia-smi is not installed" + def initialize_lowest_validation_loss(logfile_name, arguments): - """Initialize the lowest validation loss from the logfile if it exists and if the training - does not start from epoch 0, which means that a previous training session is resumed. - + """Initialize the lowest validation loss from the logfile if it exists and + if the training does not start from epoch 0, which means that a previous + training session is resumed. + Parameters ---------- @@ -80,10 +82,10 @@ def initialize_lowest_validation_loss(logfile_name, arguments): if arguments["epoch"] != 0 and os.path.exists(logfile_name): # Open the CSV file - with open(logfile_name, 'r') as file: + with open(logfile_name) as file: reader = csv.DictReader(file) column_name = "validation_loss" - + if column_name not in reader.fieldnames: return sys.float_info.max @@ -91,11 +93,14 @@ def initialize_lowest_validation_loss(logfile_name, arguments): values = [float(row[column_name]) for row in reader] lowest_value = min(values) - logger.info(f"Found lowest validation error from previous session: {lowest_value}") + logger.info( + f"Found lowest validation error from previous session: {lowest_value}" + ) return lowest_value return sys.float_info.max + def save_model_summary(output_folder, model): """Save a little summary of the model in a txt file.