# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for our CLI applications."""
import contextlib
import re
import pytest
from click.testing import CliRunner
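
# NOTE (test-suite assumption): the `temporary_basedir` and `datadir` fixtures
# and the `slow`/`skip_if_rc_var_not_set` markers used below are expected to be
# provided by this suite's conftest.py / pytest configuration; they are not
# defined in this module.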


@contextlib.contextmanager
def stdout_logging():
    """Capture mednet log messages into a string buffer for inspection."""
    import io
    import logging

    buf = io.StringIO()
    ch = logging.StreamHandler(buf)
    ch.setFormatter(logging.Formatter("%(message)s"))
    ch.setLevel(logging.INFO)
    logger = logging.getLogger("mednet")
    logger.addHandler(ch)
    try:
        yield buf
    finally:
        # remove the handler even if the test body raises
        logger.removeHandler(ch)


def _assert_exit_0(result):
    """Assert that a CLI invocation exited with code 0, dumping its output otherwise."""
    assert (
        result.exit_code == 0
    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"


def _check_help(entry_point):
    """Assert that ``--help`` works for the given command and prints a usage line."""
    runner = CliRunner()
    result = runner.invoke(entry_point, ["--help"])
    _assert_exit_0(result)
    assert result.output.startswith("Usage:")


def test_config_help():
    from mednet.libs.segmentation.scripts.config import config

    _check_help(config)


def test_config_list_help():
    from mednet.libs.segmentation.scripts.config import list_

    _check_help(list_)


def test_config_list():
    from mednet.libs.segmentation.scripts.config import list_

    runner = CliRunner()
    result = runner.invoke(list_)
    _assert_exit_0(result)
    assert "module: mednet.libs.segmentation.config.data" in result.output
    assert "module: mednet.libs.segmentation.config.models" in result.output


def test_config_list_v():
    from mednet.libs.segmentation.scripts.config import list_

    result = CliRunner().invoke(list_, ["--verbose"])
    _assert_exit_0(result)
    assert "module: mednet.libs.segmentation.config.data" in result.output
    assert "module: mednet.libs.segmentation.config.models" in result.output


def test_config_describe_help():
    from mednet.libs.segmentation.scripts.config import describe

    _check_help(describe)


@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_config_describe_drive():
    from mednet.libs.segmentation.scripts.config import describe

    runner = CliRunner()
    result = runner.invoke(describe, ["drive"])
    _assert_exit_0(result)
    assert "DRIVE dataset for Vessel Segmentation (default protocol)." in result.output


def test_database_help():
    from mednet.libs.segmentation.scripts.database import database

    _check_help(database)


def test_datamodule_list_help():
    from mednet.libs.segmentation.scripts.database import list_

    _check_help(list_)


def test_datamodule_list():
    from mednet.libs.segmentation.scripts.database import list_

    runner = CliRunner()
    result = runner.invoke(list_)
    _assert_exit_0(result)
    assert result.output.startswith("Available databases:")


def test_datamodule_check_help():
    from mednet.libs.segmentation.scripts.database import check

    _check_help(check)


@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_database_check():
    from mednet.libs.segmentation.scripts.database import check

    runner = CliRunner()
    result = runner.invoke(check, ["--verbose", "--limit=1", "drive"])
    _assert_exit_0(result)


def test_main_help():
    from mednet.libs.segmentation.scripts.cli import segmentation

    _check_help(segmentation)


def test_train_help():
    from mednet.libs.segmentation.scripts.train import train

    _check_help(train)


def _str_counter(substr, s):
    """Count how many times the regular expression ``substr`` matches in ``s``."""
    return sum(1 for _ in re.finditer(substr, s, re.MULTILINE))


def test_predict_help():
    from mednet.libs.segmentation.scripts.predict import predict

    _check_help(predict)


def test_evaluate_help():
    from mednet.libs.segmentation.scripts.evaluate import evaluate

    _check_help(evaluate)
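

# NOTE (assumption): the slow tests below exercise full training, prediction
# and evaluation runs on the DRIVE dataset; they are skipped unless the
# `datadir.drive` entry is set in the mednet runtime configuration.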


@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_train_lwnet_drive(temporary_basedir):
    from mednet.libs.common.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )
    from mednet.libs.segmentation.scripts.train import train

    runner = CliRunner()

    with stdout_logging() as buf:
        output_folder = temporary_basedir / "results"
        result = runner.invoke(
            train,
            [
                "lwnet",
                "drive",
                "-vv",
                "--epochs=1",
                "--batch-size=1",
                f"--output-folder={str(output_folder)}",
            ],
        )
        _assert_exit_0(result)

        # asserts checkpoints are there, or raises FileNotFoundError
        last = _get_checkpoint_from_alias(output_folder, "periodic")
        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
        best = _get_checkpoint_from_alias(output_folder, "best")
        assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)

        assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1
        assert (output_folder / "meta.json").exists()

        keywords = {
            r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Training for at most 1 epochs.$": 1,
            r"^Uninitialised lwnet model - computing z-norm factors from train dataloader.$": 1,
            r"^Writing run metadata at.*$": 1,
            r"^Dataset `train` is already setup. Not re-instantiating it.$": 3,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"String '{k}' appeared {_str_counter(k, logging_output)} time(s) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )


@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_train_lwnet_drive_from_checkpoint(temporary_basedir):
    from mednet.libs.common.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )
    from mednet.libs.segmentation.scripts.train import train

    runner = CliRunner()

    output_folder = temporary_basedir / "results" / "lwnet_checkpoint"
    result0 = runner.invoke(
        train,
        [
            "lwnet",
            "drive",
            "-vv",
            "--epochs=1",
            "--batch-size=1",
            f"--output-folder={str(output_folder)}",
        ],
    )
    _assert_exit_0(result0)

    # asserts checkpoints are there, or raises FileNotFoundError
    last = _get_checkpoint_from_alias(output_folder, "periodic")
    assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
    best = _get_checkpoint_from_alias(output_folder, "best")
    assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)

    assert (output_folder / "meta.json").exists()
    assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1

    with stdout_logging() as buf:
        result = runner.invoke(
            train,
            [
                "lwnet",
                "drive",
                "-vv",
                "--epochs=2",
                "--batch-size=1",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # asserts checkpoints are there, or raises FileNotFoundError
        last = _get_checkpoint_from_alias(output_folder, "periodic")
        assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION)
        # the best checkpoint epoch depends on validation-loss evolution, so
        # only its presence is checked here
        best = _get_checkpoint_from_alias(output_folder, "best")

        assert (output_folder / "meta.json").exists()
        assert len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 2

        keywords = {
            r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Training for at most 2 epochs.$": 1,
            r"^Resuming from epoch 0 \(checkpoint file: .*$": 1,
            r"^Writing run metadata at.*$": 1,
            r"^Dataset `train` is already setup. Not re-instantiating it.$": 3,
            r"^Restoring normalizer from checkpoint.$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"String '{k}' appeared {_str_counter(k, logging_output)} time(s) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )
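

# NOTE (assumption): the prediction and evaluation tests below reuse the
# checkpoints and predictions written into `temporary_basedir` by the tests
# above, so they rely on the training test having run first (e.g. through a
# session-scoped `temporary_basedir` fixture).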


@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_predict_lwnet_drive(temporary_basedir, datadir):
    from mednet.libs.common.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )
    from mednet.libs.segmentation.scripts.predict import predict

    runner = CliRunner()

    with stdout_logging() as buf:
        output = temporary_basedir / "predictions"
        last = _get_checkpoint_from_alias(
            temporary_basedir / "results",
            "periodic",
        )
        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
        result = runner.invoke(
            predict,
            [
                "lwnet",
                "drive",
                "-vv",
                "--batch-size=1",
                f"--weight={str(last)}",
                f"--output-folder={str(output)}",
            ],
        )
        _assert_exit_0(result)

        assert output.exists()

        keywords = {
            r"^Loading dataset: * without caching. Trade-off: CPU RAM usage: less | Disk I/O: more$": 2,
            r"^Loading checkpoint from .*$": 1,
            r"^Restoring normalizer from checkpoint.$": 1,
            r"^Running prediction on `train` split...$": 1,
            r"^Running prediction on `test` split...$": 1,
            r"^Predictions saved to .*$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"String '{k}' appeared {_str_counter(k, logging_output)} time(s) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )


@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_evaluate_lwnet_drive(temporary_basedir):
    from mednet.libs.segmentation.scripts.evaluate import evaluate

    runner = CliRunner()

    with stdout_logging() as buf:
        prediction_path = temporary_basedir / "predictions"
        predictions_file = prediction_path / "predictions.json"
        evaluation_path = temporary_basedir / "evaluations"
        result = runner.invoke(
            evaluate,
            [
                "-vv",
                "drive",
                f"--predictions={predictions_file}",
                f"--output-folder={evaluation_path}",
                "--threshold=test",
            ],
        )
        _assert_exit_0(result)

        assert (evaluation_path / "evaluation.json").exists()
        assert (evaluation_path / "evaluation.meta.json").exists()
        assert (evaluation_path / "evaluation.pdf").exists()
        assert (evaluation_path / "evaluation.rst").exists()

        keywords = {
            r"^Writing run metadata at.*$": 1,
            r"^Counting true/false positive/negatives at split.*$": 2,
            r"^Evaluating threshold on.*$": 1,
            r"^Tabulating performance summary...": 1,
            r"^Saving evaluation results at.*$": 1,
            r"^Saving table at .*$": 1,
            r"^Plotting performance curves...": 1,
            r"^Saving figures at .*$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"String '{k}' appeared {_str_counter(k, logging_output)} time(s) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )


@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_experiment(temporary_basedir):
    from mednet.libs.segmentation.scripts.experiment import experiment

    runner = CliRunner()

    output_folder = temporary_basedir / "experiment"
    num_epochs = 2
    result = runner.invoke(
        experiment,
        [
            "-vv",
            "lwnet",
            "drive",
            f"--epochs={num_epochs}",
            f"--output-folder={str(output_folder)}",
        ],
    )
    _assert_exit_0(result)

    assert (output_folder / "model" / "meta.json").exists()
    assert (output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists()
    assert (output_folder / "predictions" / "predictions.json").exists()
    assert (output_folder / "predictions" / "predictions.meta.json").exists()

    # Need to glob because we cannot be sure of the checkpoint with lowest
    # validation loss
    assert (
        len(
            list(
                (output_folder / "model").glob(
                    "model-at-lowest-validation-loss-epoch=*.ckpt",
                ),
            ),
        )
        == 1
    )
    assert (output_folder / "model" / "trainlog.pdf").exists()
    assert (
        len(
            list(
                (output_folder / "model" / "logs").glob(
                    "events.out.tfevents.*",
                ),
            ),
        )
        == 1
    )
    assert (output_folder / "evaluation" / "evaluation.json").exists()
    assert (output_folder / "evaluation" / "evaluation.meta.json").exists()
    assert (output_folder / "evaluation" / "evaluation.pdf").exists()
    assert (output_folder / "evaluation" / "evaluation.rst").exists()