diff --git a/bob/ip/binseg/script/experiment.py b/bob/ip/binseg/script/experiment.py index 30d14e2c7e51ff0de20801b9d8a0bdabc8c2ea27..f43fe5fc7b117d6089a8215906f6d032074b1af1 100644 --- a/bob/ip/binseg/script/experiment.py +++ b/bob/ip/binseg/script/experiment.py @@ -326,7 +326,11 @@ def experiment( from .analyze import analyze - model_file = os.path.join(train_output_folder, "model_final.pth") + # preferably, we use the best model on the validation set + # otherwise, we get the last saved model + model_file = os.path.join(train_output_folder, "model_lowest_valid_loss.pth") + if not os.path.exists(model_file): + model_file = os.path.join(train_output_folder, "model_final.pth") ctx.invoke( analyze,