diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index e4a2f5eac865ae870207657219b120b21cc9be64..5ea8ccae4cfecb6987fb7affca1274e6bb474a4e 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -93,8 +93,6 @@ def run( main_pid=os.getpid(), ) - monitor_key = "loss/validation" - # This checkpointer will operate at the end of every validation epoch # (which happens at each checkpoint period), it will then save the lowest # validation loss model observed. It will also save the last trained model @@ -102,7 +100,7 @@ def run( dirpath=output_folder, filename=CHECKPOINT_ALIASES["best"], save_last=True, # will (re)create the last trained model, at every iteration - monitor=monitor_key, + monitor="loss/validation", mode="min", save_on_train_epoch_end=True, every_n_epochs=validation_period, # frequency at which it checks the "monitor"