diff --git a/tests/test_cli.py b/tests/test_cli.py index 4828e1d09ca51d1b06c5cb3c29fb1f8b62ddccd0..f56bc43f77a33a9a4e5d998676ae70299cbfac01 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -436,6 +436,70 @@ def test_evaluate_pasa_montgomery(temporary_basedir): ) +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +def test_experiment(temporary_basedir): + from mednet.scripts.experiment import experiment + + runner = CliRunner() + + output_folder = str(temporary_basedir / "experiment") + num_epochs = 2 + result = runner.invoke( + experiment, + [ + "-vv", + "pasa", + "montgomery", + f"--epochs={num_epochs}", + f"--output-folder={output_folder}", + ], + ) + _assert_exit_0(result) + + assert os.path.exists(os.path.join(output_folder, "command.sh")) + assert os.path.exists(os.path.join(output_folder, "predictions.json")) + assert os.path.exists(os.path.join(output_folder, "model", "command.sh")) + assert os.path.exists(os.path.join(output_folder, "model", "constants.csv")) + assert os.path.exists( + os.path.join( + output_folder, "model", f"model-at-epoch={num_epochs-1}.ckpt" + ) + ) + # Need to glob because we cannot be sure of the checkpoint with lowest validation loss + assert ( + len( + glob.glob( + os.path.join( + output_folder, + "model", + "model-at-lowest-validation-loss-epoch=*.ckpt", + ) + ) + ) + == 1 + ) + assert os.path.exists( + os.path.join(output_folder, "model", "model-summary.txt") + ) + assert os.path.exists(os.path.join(output_folder, "model", "trainlog.pdf")) + assert ( + len( + glob.glob( + os.path.join( + output_folder, "model", "logs", "events.out.tfevents.*" + ) + ) + ) + == 1 + ) + assert os.path.exists( + os.path.join(output_folder, "evaluation", "plots.pdf") + ) + assert os.path.exists( + os.path.join(output_folder, "evaluation", "summary.rst") + ) + + # This script does not work anymore, either fix or remove the script + this test # def test_evaluatevis(temporary_basedir): # import pandas as pd