Skip to content
Snippets Groups Projects
test_cli.py 14.9 KiB
Newer Older
André Anjos's avatar
André Anjos committed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for our CLI applications."""

import contextlib
import re

import pytest
from click.testing import CliRunner


@contextlib.contextmanager
def stdout_logging():
    """Copy ``mednet`` log messages into a memory buffer for inspection.

    Yields
    ------
    io.StringIO
        Buffer receiving every record (INFO and above) emitted through the
        ``mednet`` logger while the context is active.
    """

    import io
    import logging

    buf = io.StringIO()
    ch = logging.StreamHandler(buf)
    ch.setFormatter(logging.Formatter("%(message)s"))
    ch.setLevel(logging.INFO)
    logger = logging.getLogger("mednet")

    logger.addHandler(ch)
    try:
        yield buf
    finally:
        # detach the handler even if the body raises, so one failing test
        # does not leak its handler into every subsequent test
        logger.removeHandler(ch)


def _assert_exit_0(result):
    """Assert a CLI invocation exited with code 0, showing its output on failure."""
    message = f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
    assert result.exit_code == 0, message


def _check_help(entry_point):
    """Invoke ``entry_point --help`` and check it exits cleanly with usage text."""
    result = CliRunner().invoke(entry_point, ["--help"])
    _assert_exit_0(result)
    assert result.output.startswith("Usage:")


def test_info_help():
    """The ``info`` command should provide a working ``--help``."""
    from mednet.scripts.info import info as entry

    _check_help(entry)


def test_info():
    """Running ``info`` should report platform, database and dependency details."""
    from mednet.scripts.info import info

    result = CliRunner().invoke(info)
    _assert_exit_0(result)
    expected = (
        "platform:",
        "accelerators:",
        "version:",
        "configured databases:",
        "dependencies:",
        "python:",
    )
    for fragment in expected:
        assert fragment in result.output


André Anjos's avatar
André Anjos committed
def test_config_help():
    """The ``config`` command group should provide a working ``--help``."""
    from mednet.scripts.config import config as entry

    _check_help(entry)


def test_config_list_help():
André Anjos's avatar
André Anjos committed

André Anjos's avatar
André Anjos committed


def test_config_list():
André Anjos's avatar
André Anjos committed

    runner = CliRunner()
André Anjos's avatar
André Anjos committed
    _assert_exit_0(result)
    assert "module: mednet.config.data" in result.output
    assert "module: mednet.config.models" in result.output
André Anjos's avatar
André Anjos committed


def test_config_list_v():
André Anjos's avatar
André Anjos committed

André Anjos's avatar
André Anjos committed
    _assert_exit_0(result)
    assert "module: mednet.config.data" in result.output
    assert "module: mednet.config.models" in result.output
André Anjos's avatar
André Anjos committed


def test_config_describe_help():
    """The ``config describe`` command should provide a working ``--help``."""
    from mednet.scripts.config import describe as entry

    _check_help(entry)


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_config_describe_montgomery():
    """Describing the ``montgomery`` config should print its module docstring."""
    from mednet.scripts.config import describe

    result = CliRunner().invoke(describe, ["montgomery"])
    _assert_exit_0(result)
    assert "Montgomery DataModule for TB detection." in result.output
    from mednet.scripts.database import database
André Anjos's avatar
André Anjos committed

def test_datamodule_list_help():
André Anjos's avatar
André Anjos committed

def test_datamodule_list():
André Anjos's avatar
André Anjos committed

    runner = CliRunner()
André Anjos's avatar
André Anjos committed
    _assert_exit_0(result)
    assert result.output.startswith("Available databases:")
def test_datamodule_check_help():
    """The ``database check`` command should provide a working ``--help``."""
    from mednet.scripts.database import check as entry

    _check_help(entry)


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
    from mednet.scripts.database import check
André Anjos's avatar
André Anjos committed

    runner = CliRunner()
    result = runner.invoke(check, ["--verbose", "--limit=1", "montgomery"])
André Anjos's avatar
André Anjos committed
    _assert_exit_0(result)


def test_main_help():
    """The top-level ``cli`` entry point should provide a working ``--help``."""
    from mednet.scripts.cli import cli as entry

    _check_help(entry)


def test_train_help():
    """The ``train`` command should provide a working ``--help``."""
    from mednet.scripts.train import train as entry

    _check_help(entry)


def _str_counter(substr, s):
    """Count non-overlapping matches of regex ``substr`` in ``s`` (MULTILINE mode)."""
    return len(re.findall(substr, s, re.MULTILINE))


def test_predict_help():
    """The ``predict`` command should provide a working ``--help``."""
    from mednet.scripts.predict import predict as entry

    _check_help(entry)


def test_evaluate_help():
    """The ``evaluate`` command should provide a working ``--help``."""
    from mednet.scripts.evaluate import evaluate as entry

    _check_help(entry)
def test_saliency_generate_help():
    """The ``saliency generate`` command should provide a working ``--help``."""
    from mednet.scripts.saliency.generate import generate as entry

    _check_help(entry)
def test_saliency_completeness_help():
    """The ``saliency completeness`` command should provide a working ``--help``."""
    from mednet.scripts.saliency.completeness import completeness as entry

    _check_help(entry)


def test_saliency_view_help():
    """The ``saliency view`` command should provide a working ``--help``."""
    from mednet.scripts.saliency.view import view as entry

    _check_help(entry)
def test_saliency_evaluate_help():
    """The ``saliency evaluate`` command should provide a working ``--help``."""
    from mednet.scripts.saliency.evaluate import evaluate as entry

    _check_help(entry)


def test_upload_help():
    """The ``upload`` command should provide a working ``--help``."""
    from mednet.scripts.upload import upload as entry

    _check_help(entry)


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery(temporary_basedir):
    from mednet.scripts.train import train
    from mednet.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )
André Anjos's avatar
André Anjos committed

André Anjos's avatar
André Anjos committed

        output_folder = temporary_basedir / "results"
        result = runner.invoke(
            train,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--epochs=1",
                "--batch-size=1",
                f"--output-folder={str(output_folder)}",
André Anjos's avatar
André Anjos committed

        # asserts checkpoints are there, or raises FileNotFoundError
        last = _get_checkpoint_from_alias(output_folder, "periodic")
        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
        best = _get_checkpoint_from_alias(output_folder, "best")
        assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)

        assert (
            len(list((output_folder / "logs").glob("events.out.tfevents.*")))
Daniel CARRON's avatar
Daniel CARRON committed
        )
        assert (output_folder / "meta.json").exists()
            r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
Daniel CARRON's avatar
Daniel CARRON committed
            r"^Applying train/valid loss balancing...$": 1,
            r"^Training for at most 1 epochs.$": 1,
            r"^Uninitialised pasa model - computing z-norm factors from train dataloader.$": 1,
            r"^Writing run metadata at.*$": 1,
            r"^Dataset `train` is already setup. Not re-instantiating it.$": 1,
            r"^Dataset `validation` is already setup. Not re-instantiating it.$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
André Anjos's avatar
André Anjos committed
            )

@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
    from mednet.scripts.train import train
    from mednet.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )
    output_folder = temporary_basedir / "results" / "pasa_checkpoint"
    result0 = runner.invoke(
        train,
        [
            "pasa",
            "montgomery",
            "-vv",
            "--epochs=1",
            "--batch-size=1",
            f"--output-folder={str(output_folder)}",
        ],
    )
    _assert_exit_0(result0)

    # asserts checkpoints are there, or raises FileNotFoundError
    last = _get_checkpoint_from_alias(output_folder, "periodic")
    assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
    best = _get_checkpoint_from_alias(output_folder, "best")
    assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)

    assert (output_folder / "meta.json").exists()
    assert (
        len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1
Daniel CARRON's avatar
Daniel CARRON committed
    )
    with stdout_logging() as buf:
        result = runner.invoke(
            train,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--epochs=2",
                "--batch-size=1",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # asserts checkpoints are there, or raises FileNotFoundError
        last = _get_checkpoint_from_alias(output_folder, "periodic")
        assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION)
        best = _get_checkpoint_from_alias(output_folder, "best")

        assert (output_folder / "meta.json").exists()
        assert (
            len(list((output_folder / "logs").glob("events.out.tfevents.*")))
Daniel CARRON's avatar
Daniel CARRON committed
        )
            r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
Daniel CARRON's avatar
Daniel CARRON committed
            r"^Applying train/valid loss balancing...$": 1,
            r"^Training for at most 2 epochs.$": 1,
            r"^Resuming from epoch 0 \(checkpoint file: .*$": 1,
            r"^Writing run metadata at.*$": 1,
            r"^Dataset `train` is already setup. Not re-instantiating it.$": 1,
            r"^Dataset `validation` is already setup. Not re-instantiating it.$": 1,
            r"^Restoring normalizer from checkpoint.$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_predict_pasa_montgomery(temporary_basedir):
    from mednet.scripts.predict import predict
    from mednet.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )
        output = temporary_basedir / "predictions.json"
        last = _get_checkpoint_from_alias(
        )
        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
        result = runner.invoke(
            predict,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--batch-size=1",
            r"^Loading dataset: * without caching. Trade-off: CPU RAM usage: less | Disk I/O: more$": 3,
            r"^Restoring normalizer from checkpoint.$": 1,
            r"^Running prediction on `train` split...$": 1,
            r"^Running prediction on `validation` split...$": 1,
            r"^Running prediction on `test` split...$": 1,
            r"^Predictions saved to .*$": 1,
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
André Anjos's avatar
André Anjos committed
            )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_evaluate_pasa_montgomery(temporary_basedir):
    from mednet.scripts.evaluate import evaluate
        prediction_path = temporary_basedir / "predictions.json"
Daniel CARRON's avatar
Daniel CARRON committed
        evaluation_filename = "evaluation.json"
        evaluation_file = temporary_basedir / evaluation_filename
                f"--predictions={str(prediction_path)}",
Daniel CARRON's avatar
Daniel CARRON committed
                f"--output-folder={str(temporary_basedir)}",
                "--threshold=test",
Daniel CARRON's avatar
Daniel CARRON committed
        assert evaluation_file.exists()
        assert evaluation_file.with_suffix(".meta.json").exists()
        assert evaluation_file.with_suffix(".rst").exists()
        assert evaluation_file.with_suffix(".pdf").exists()
André Anjos's avatar
André Anjos committed

        keywords = {
            r"^Setting --threshold=.*$": 1,
            r"^Analyzing split `train`...$": 1,
            r"^Analyzing split `validation`...$": 1,
            r"^Analyzing split `test`...$": 1,
            r"^Saving evaluation results .*$": 2,
            r"^Saving evaluation figures at .*$": 1,
André Anjos's avatar
André Anjos committed
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_experiment(temporary_basedir):
    from mednet.scripts.experiment import experiment

    runner = CliRunner()

    output_folder = temporary_basedir / "experiment"
    num_epochs = 2
    result = runner.invoke(
        experiment,
        [
            "-vv",
            "pasa",
            "montgomery",
            f"--epochs={num_epochs}",
            f"--output-folder={str(output_folder)}",
        ],
    )
    _assert_exit_0(result)

    assert (output_folder / "model" / "meta.json").exists()
    assert (
        output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt"
    ).exists()
    assert (output_folder / "predictions.json").exists()
    assert (output_folder / "predictions.meta.json").exists()

    # Need to glob because we cannot be sure of the checkpoint with lowest validation loss
    assert (
        len(
            list(
                (output_folder / "model").glob(
    assert (output_folder / "trainlog.pdf").exists()
                (output_folder / "model" / "logs").glob(
                    "events.out.tfevents.*",
                ),
            ),
    assert (output_folder / "evaluation.json").exists()
    assert (output_folder / "evaluation.meta.json").exists()
    assert (output_folder / "evaluation.rst").exists()
    assert (output_folder / "evaluation.pdf").exists()
    assert (output_folder / "gradcam" / "saliencies").exists()
            list(
                (output_folder / "gradcam" / "saliencies" / "CXR_png").glob(
    assert (output_folder / "gradcam" / "visualizations").exists()
            list(
                (output_folder / "gradcam" / "visualizations" / "CXR_png").glob(