Skip to content
Snippets Groups Projects
Commit 2ff36c7b authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[tests] Add tests for experiment

parent 2b40fbd6
No related branches found
No related tags found
2 merge requests!18Update tests,!16Make square centre-padding a model transform
...@@ -436,6 +436,70 @@ def test_evaluate_pasa_montgomery(temporary_basedir): ...@@ -436,6 +436,70 @@ def test_evaluate_pasa_montgomery(temporary_basedir):
) )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_experiment(temporary_basedir):
from mednet.scripts.experiment import experiment
runner = CliRunner()
output_folder = str(temporary_basedir / "experiment")
num_epochs = 2
result = runner.invoke(
experiment,
[
"-vv",
"pasa",
"montgomery",
f"--epochs={num_epochs}",
f"--output-folder={output_folder}",
],
)
_assert_exit_0(result)
assert os.path.exists(os.path.join(output_folder, "command.sh"))
assert os.path.exists(os.path.join(output_folder, "predictions.json"))
assert os.path.exists(os.path.join(output_folder, "model", "command.sh"))
assert os.path.exists(os.path.join(output_folder, "model", "constants.csv"))
assert os.path.exists(
os.path.join(
output_folder, "model", f"model-at-epoch={num_epochs-1}.ckpt"
)
)
# Need to glob because we cannot be sure of the checkpoint with lowest validation loss
assert (
len(
glob.glob(
os.path.join(
output_folder,
"model",
"model-at-lowest-validation-loss-epoch=*.ckpt",
)
)
)
== 1
)
assert os.path.exists(
os.path.join(output_folder, "model", "model-summary.txt")
)
assert os.path.exists(os.path.join(output_folder, "model", "trainlog.pdf"))
assert (
len(
glob.glob(
os.path.join(
output_folder, "model", "logs", "events.out.tfevents.*"
)
)
)
== 1
)
assert os.path.exists(
os.path.join(output_folder, "evaluation", "plots.pdf")
)
assert os.path.exists(
os.path.join(output_folder, "evaluation", "summary.rst")
)
# This script does not work anymore, either fix or remove the script + this test # This script does not work anymore, either fix or remove the script + this test
# def test_evaluatevis(temporary_basedir): # def test_evaluatevis(temporary_basedir):
# import pandas as pd # import pandas as pd
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment