def initialize_lowest_validation_loss(logfile_name, arguments):
    """Return the lowest validation loss logged by a previous session.

    The logfile is only consulted when training does not start from epoch 0
    (i.e. a previous session is being resumed) and the file exists.  In every
    other case, or when no usable value can be read from the file,
    ``sys.float_info.max`` is returned so that any later validation loss
    compares as lower.

    Parameters
    ----------

    logfile_name : str
        Path to the CSV training log.  It is read with
        :py:class:`csv.DictReader` and is expected to carry a
        ``validation_loss`` column.

    arguments : dict
        Training arguments; only the ``"epoch"`` key (the epoch training
        starts from) is read here.

    Returns
    -------

    float
        The smallest parsable, non-NaN value found in the
        ``validation_loss`` column, or ``sys.float_info.max`` if there is
        none.
    """
    import math

    # Nothing to recover on a fresh run or when no log was written yet.
    if arguments["epoch"] == 0 or not os.path.exists(logfile_name):
        return sys.float_info.max

    column_name = "validation_loss"

    # newline="" is what the csv module recommends for files it parses.
    with open(logfile_name, newline="") as file:
        reader = csv.DictReader(file)

        # ``fieldnames`` is None for an empty file, so test it before the
        # membership check to avoid a TypeError.
        if not reader.fieldnames or column_name not in reader.fieldnames:
            return sys.float_info.max

        values = []
        for row in reader:
            try:
                value = float(row[column_name])
            except (TypeError, ValueError):
                # Blank cell (ValueError) or truncated row, for which
                # DictReader fills the column with None (TypeError).
                continue
            # NaN would make ``min`` order-dependent and, once returned,
            # would compare false against every later validation loss.
            if not math.isnan(value):
                values.append(value)

    if not values:
        return sys.float_info.max

    lowest_value = min(values)
    logger.info(
        f"Found lowest validation loss from previous session: {lowest_value}"
    )
    return lowest_value
@@ -569,7 +608,9 @@ def run( # the lowest validation loss obtained so far - this value is updated only # if a validation set is available - lowest_validation_loss = sys.float_info.max + lowest_validation_loss = initialize_lowest_validation_loss( + logfile_name, arguments + ) # set a specific validation criterion if the user has set one criterion_valid = criterion_valid or criterion diff --git a/tests/test_cli.py b/tests/test_cli.py index 31edf50194435d9ba064557c16892ebff467cfd4..2feb5e6c9370374874b959b20bf3cad5dda10283 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -288,7 +288,7 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): r"^Continuing from epoch 1$": 1, r"^Saving model summary at.*$": 1, r"^Model has.*$": 1, - r"^Saving checkpoint": 2, + r"^Found lowest validation loss from previous session.*$": 1, r"^Total training time:": 1, r"^Z-normalization with mean": 1, } @@ -302,6 +302,11 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): f"instead of the expected {v}:\nOutput:\n{logging_output}" ) + extra_keyword = "Saving checkpoint" + assert ( + extra_keyword in logging_output + ), f"String '{extra_keyword}' did not appear at least once in the output:\nOutput:\n{logging_output}" + @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_predict_pasa_montgomery(temporary_basedir, datadir):