Skip to content
Snippets Groups Projects
test_cli.py 14.4 KiB
Newer Older
André Anjos's avatar
André Anjos committed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for our CLI applications."""

import contextlib
import re

import pytest

from click.testing import CliRunner


@contextlib.contextmanager
def stdout_logging():
    """Capture messages of the ``mednet`` logger into an in-memory buffer.

    Attaches a temporary ``StreamHandler`` (INFO level, message-only
    formatting) to the ``mednet`` logger for the duration of the context.

    Yields
    ------
    io.StringIO
        Buffer receiving every record the ``mednet`` logger emits while
        the context is active.
    """

    import io
    import logging

    buf = io.StringIO()
    ch = logging.StreamHandler(buf)
    ch.setFormatter(logging.Formatter("%(message)s"))
    ch.setLevel(logging.INFO)
    logger = logging.getLogger("mednet")

    logger.addHandler(ch)
    try:
        yield buf
    finally:
        # detach the handler even if the body raises, so one failing test
        # does not leak a handler into every subsequent test
        logger.removeHandler(ch)


def _assert_exit_0(result):
    assert (
        result.exit_code == 0
    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"


def _check_help(entry_point):
    """Invoke *entry_point* with ``--help`` and verify it prints usage text."""
    outcome = CliRunner().invoke(entry_point, ["--help"])
    _assert_exit_0(outcome)
    assert outcome.output.startswith("Usage:")


def test_config_help():
    """The ``config`` command group must provide a working ``--help``."""
    from mednet.scripts.config import config as config_group

    _check_help(config_group)


def test_config_list_help():
    """``config list`` must provide a working ``--help``."""
    from mednet.scripts.config import list as config_list

    _check_help(config_list)


def test_config_list():
    """``config list`` must report the data and model configuration modules."""
    from mednet.scripts.config import list as config_list

    outcome = CliRunner().invoke(config_list)
    _assert_exit_0(outcome)
    for expected in (
        "module: mednet.config.data",
        "module: mednet.config.models",
    ):
        assert expected in outcome.output
André Anjos's avatar
André Anjos committed


def test_config_list_v():
    """``config list --verbose`` must still report the configuration modules."""
    from mednet.scripts.config import list as config_list

    outcome = CliRunner().invoke(config_list, ["--verbose"])
    _assert_exit_0(outcome)
    for expected in (
        "module: mednet.config.data",
        "module: mednet.config.models",
    ):
        assert expected in outcome.output
André Anjos's avatar
André Anjos committed


def test_config_describe_help():
    """``config describe`` must provide a working ``--help``."""
    from mednet.scripts.config import describe as config_describe

    _check_help(config_describe)


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_config_describe_montgomery():
    """``config describe montgomery`` must print the DataModule description."""
    from mednet.scripts.config import describe as config_describe

    outcome = CliRunner().invoke(config_describe, ["montgomery"])
    _assert_exit_0(outcome)
    assert "Montgomery DataModule for TB detection." in outcome.output
André Anjos's avatar
André Anjos committed


    from mednet.scripts.database import database
André Anjos's avatar
André Anjos committed

André Anjos's avatar
André Anjos committed


def test_datamodule_list_help():
    """``database list`` must provide a working ``--help``."""
    from mednet.scripts.database import list as database_list

    _check_help(database_list)


def test_datamodule_list():
    """``database list`` must run cleanly and enumerate known databases."""
    from mednet.scripts.database import list as database_list

    outcome = CliRunner().invoke(database_list)
    _assert_exit_0(outcome)
    assert outcome.output.startswith("Available databases:")


def test_datamodule_check_help():
    """``database check`` must provide a working ``--help``."""
    from mednet.scripts.database import check as database_check

    _check_help(database_check)


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
    from mednet.scripts.database import check
André Anjos's avatar
André Anjos committed

    runner = CliRunner()
    result = runner.invoke(check, ["--verbose", "--limit=1", "montgomery"])
André Anjos's avatar
André Anjos committed
    _assert_exit_0(result)


def test_main_help():
    """The top-level ``mednet`` CLI must provide a working ``--help``."""
    from mednet.scripts.cli import cli as main_cli

    _check_help(main_cli)


def test_train_help():
    """The ``train`` command must provide a working ``--help``."""
    from mednet.scripts.train import train as train_cmd

    _check_help(train_cmd)


def _str_counter(substr, s):
    return sum(1 for _ in re.finditer(substr, s, re.MULTILINE))


def test_predict_help():
    """The ``predict`` command must provide a working ``--help``."""
    from mednet.scripts.predict import predict as predict_cmd

    _check_help(predict_cmd)


def test_evaluate_help():
    """The ``evaluate`` command must provide a working ``--help``."""
    from mednet.scripts.evaluate import evaluate as evaluate_cmd

    _check_help(evaluate_cmd)


def test_saliency_generate_help():
    """``saliency generate`` must provide a working ``--help``."""
    from mednet.scripts.saliency.generate import generate as saliency_generate

    _check_help(saliency_generate)


def test_saliency_completeness_help():
    """``saliency completeness`` must provide a working ``--help``."""
    from mednet.scripts.saliency.completeness import (
        completeness as saliency_completeness,
    )

    _check_help(saliency_completeness)


def test_saliency_view_help():
    """``saliency view`` must provide a working ``--help``."""
    from mednet.scripts.saliency.view import view as saliency_view

    _check_help(saliency_view)


def test_saliency_evaluate_help():
    """``saliency evaluate`` must provide a working ``--help``."""
    from mednet.scripts.saliency.evaluate import evaluate as saliency_evaluate

    _check_help(saliency_evaluate)


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery(temporary_basedir):
    from mednet.scripts.train import train
    from mednet.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )
André Anjos's avatar
André Anjos committed

André Anjos's avatar
André Anjos committed

        output_folder = temporary_basedir / "results"
        result = runner.invoke(
            train,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--epochs=1",
                "--batch-size=1",
                f"--output-folder={str(output_folder)}",
André Anjos's avatar
André Anjos committed

        # asserts checkpoints are there, or raises FileNotFoundError
        last = _get_checkpoint_from_alias(output_folder, "periodic")
        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
        best = _get_checkpoint_from_alias(output_folder, "best")
        assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)

        assert (
            len(list((output_folder / "logs").glob("events.out.tfevents.*")))
Daniel CARRON's avatar
Daniel CARRON committed
        )
        assert (output_folder / "meta.json").exists()
            r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
Daniel CARRON's avatar
Daniel CARRON committed
            r"^Applying DataModule train sampler balancing...$": 1,
            r"^Balancing samples from dataset using metadata targets `label`$": 1,
            r"^Training for at most 1 epochs.$": 1,
            r"^Uninitialised pasa model - computing z-norm factors from train dataloader.$": 1,
            r"^Writing run metadata at.*$": 1,
            r"^Dataset `train` is already setup. Not re-instantiating it.$": 1,
            r"^Dataset `validation` is already setup. Not re-instantiating it.$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
André Anjos's avatar
André Anjos committed
            )

@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
    from mednet.scripts.train import train
    from mednet.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )
    output_folder = temporary_basedir / "results" / "pasa_checkpoint"
    result0 = runner.invoke(
        train,
        [
            "pasa",
            "montgomery",
            "-vv",
            "--epochs=1",
            "--batch-size=1",
            f"--output-folder={str(output_folder)}",
        ],
    )
    _assert_exit_0(result0)

    # asserts checkpoints are there, or raises FileNotFoundError
    last = _get_checkpoint_from_alias(output_folder, "periodic")
    assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
    best = _get_checkpoint_from_alias(output_folder, "best")
    assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)

    assert (output_folder / "meta.json").exists()
    assert (
        len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1
Daniel CARRON's avatar
Daniel CARRON committed
    )
    with stdout_logging() as buf:
        result = runner.invoke(
            train,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--epochs=2",
                "--batch-size=1",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # asserts checkpoints are there, or raises FileNotFoundError
        last = _get_checkpoint_from_alias(output_folder, "periodic")
        assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION)
        best = _get_checkpoint_from_alias(output_folder, "best")

        assert (output_folder / "meta.json").exists()
        assert (
            len(list((output_folder / "logs").glob("events.out.tfevents.*")))
Daniel CARRON's avatar
Daniel CARRON committed
        )
            r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
Daniel CARRON's avatar
Daniel CARRON committed
            r"^Applying DataModule train sampler balancing...$": 1,
            r"^Balancing samples from dataset using metadata targets `label`$": 1,
            r"^Training for at most 2 epochs.$": 1,
            r"^Resuming from epoch 0 \(checkpoint file: .*$": 1,
            r"^Writing run metadata at.*$": 1,
            r"^Dataset `train` is already setup. Not re-instantiating it.$": 1,
            r"^Dataset `validation` is already setup. Not re-instantiating it.$": 1,
            r"^Restoring normalizer from checkpoint.$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_predict_pasa_montgomery(temporary_basedir):
    from mednet.scripts.predict import predict
    from mednet.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )
        output = temporary_basedir / "predictions.json"
        last = _get_checkpoint_from_alias(
            temporary_basedir / "results", "periodic"
        )
        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
        result = runner.invoke(
            predict,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--batch-size=1",
            r"^Loading dataset: * without caching. Trade-off: CPU RAM usage: less | Disk I/O: more$": 3,
            r"^Restoring normalizer from checkpoint.$": 1,
            r"^Running prediction on `train` split...$": 1,
            r"^Running prediction on `validation` split...$": 1,
            r"^Running prediction on `test` split...$": 1,
            r"^Predictions saved to .*$": 1,
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
André Anjos's avatar
André Anjos committed
            )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_evaluate_pasa_montgomery(temporary_basedir):
    from mednet.scripts.evaluate import evaluate
        prediction_path = temporary_basedir / "predictions.json"
        output_path = temporary_basedir / "evaluation.json"
                f"--predictions={str(prediction_path)}",
                f"--output={str(output_path)}",
                "--threshold=test",
        assert output_path.with_suffix(".meta.json").exists()
        assert output_path.with_suffix(".rst").exists()
        assert output_path.with_suffix(".pdf").exists()
André Anjos's avatar
André Anjos committed

        keywords = {
            r"^Setting --threshold=.*$": 1,
            r"^Analyzing split `train`...$": 1,
            r"^Analyzing split `validation`...$": 1,
            r"^Analyzing split `test`...$": 1,
            r"^Saving evaluation results .*$": 2,
            r"^Saving evaluation figures at .*$": 1,
André Anjos's avatar
André Anjos committed
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_experiment(temporary_basedir):
    from mednet.scripts.experiment import experiment

    runner = CliRunner()

    output_folder = temporary_basedir / "experiment"
    num_epochs = 2
    result = runner.invoke(
        experiment,
        [
            "-vv",
            "pasa",
            "montgomery",
            f"--epochs={num_epochs}",
            f"--output-folder={str(output_folder)}",
        ],
    )
    _assert_exit_0(result)

    assert (output_folder / "model" / "meta.json").exists()
    assert (
        output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt"
    ).exists()
    assert (output_folder / "predictions.json").exists()
    assert (output_folder / "predictions.meta.json").exists()

    # Need to glob because we cannot be sure of the checkpoint with lowest validation loss
    assert (
        len(
            list(
                (output_folder / "model").glob(
                    "model-at-lowest-validation-loss-epoch=*.ckpt"
    assert (output_folder / "model" / "trainlog.pdf").exists()
            list(
                (output_folder / "model" / "logs").glob("events.out.tfevents.*")
    assert (output_folder / "evaluation.json").exists()
    assert (output_folder / "evaluation.meta.json").exists()
    assert (output_folder / "evaluation.rst").exists()
    assert (output_folder / "evaluation.pdf").exists()
    assert (output_folder / "gradcam" / "saliencies").exists()
            list(
                (output_folder / "gradcam" / "saliencies" / "CXR_png").glob(
                    "MCUCXR_*.npy"
    assert (output_folder / "gradcam" / "visualizations").exists()
            list(
                (output_folder / "gradcam" / "visualizations" / "CXR_png").glob(
                    "MCUCXR_*.png"