From 2ff36c7bac4089a7b33186ab7c844f0bdb62c1ba Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Fri, 26 Jan 2024 13:37:58 +0100 Subject: [PATCH] [tests] Add tests for experiment --- tests/test_cli.py | 64 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index 4828e1d0..f56bc43f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -436,6 +436,70 @@ def test_evaluate_pasa_montgomery(temporary_basedir): ) +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +def test_experiment(temporary_basedir): + from mednet.scripts.experiment import experiment + + runner = CliRunner() + + output_folder = str(temporary_basedir / "experiment") + num_epochs = 2 + result = runner.invoke( + experiment, + [ + "-vv", + "pasa", + "montgomery", + f"--epochs={num_epochs}", + f"--output-folder={output_folder}", + ], + ) + _assert_exit_0(result) + + assert os.path.exists(os.path.join(output_folder, "command.sh")) + assert os.path.exists(os.path.join(output_folder, "predictions.json")) + assert os.path.exists(os.path.join(output_folder, "model", "command.sh")) + assert os.path.exists(os.path.join(output_folder, "model", "constants.csv")) + assert os.path.exists( + os.path.join( + output_folder, "model", f"model-at-epoch={num_epochs-1}.ckpt" + ) + ) + # Need to glob because we cannot be sure of the checkpoint with lowest validation loss + assert ( + len( + glob.glob( + os.path.join( + output_folder, + "model", + "model-at-lowest-validation-loss-epoch=*.ckpt", + ) + ) + ) + == 1 + ) + assert os.path.exists( + os.path.join(output_folder, "model", "model-summary.txt") + ) + assert os.path.exists(os.path.join(output_folder, "model", "trainlog.pdf")) + assert ( + len( + glob.glob( + os.path.join( + output_folder, "model", "logs", "events.out.tfevents.*" + ) + ) + ) + == 1 + ) + assert os.path.exists( + os.path.join(output_folder, "evaluation", "plots.pdf") + ) + assert os.path.exists( + os.path.join(output_folder, "evaluation", "summary.rst") + ) + + # This script does not work anymore, either fix or remove the script + this test # def test_evaluatevis(temporary_basedir): # import pandas as pd -- GitLab