diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index 38147f843c6ee0c4e93b241f813081f7f6c51d64..fd48d4f1fdd82eea05d28dfff7d035e5694a08fd 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -92,9 +92,12 @@ def initialize_lowest_validation_loss(logfile_name, arguments): # Get the values of the desired column as a list values = [float(row[column_name]) for row in reader] + if not values: + return sys.float_info.max + lowest_value = min(values) logger.info( - f"Found lowest validation error from previous session: {lowest_value}" + f"Found lowest validation loss from previous session: {lowest_value}" ) return lowest_value diff --git a/tests/test_cli.py b/tests/test_cli.py index 31edf50194435d9ba064557c16892ebff467cfd4..aebfb2397241469a2430917c410664a9ec436ca1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -288,7 +288,8 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): r"^Continuing from epoch 1$": 1, r"^Saving model summary at.*$": 1, r"^Model has.*$": 1, - r"^Saving checkpoint": 2, + r"^Found lowest validation loss from previous session.*$": 1, + r"^Saving checkpoint": 1, r"^Total training time:": 1, r"^Z-normalization with mean": 1, }