# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Tests for our CLI applications."""

import contextlib
import re

import pytest
from click.testing import CliRunner


@contextlib.contextmanager
def stdout_logging():
    # copy messages sent to the ``mednet`` logger into a string buffer, so
    # that tests can assert on the logged output
    import io
    import logging

    buf = io.StringIO()
    ch = logging.StreamHandler(buf)
    ch.setFormatter(logging.Formatter("%(message)s"))
    ch.setLevel(logging.INFO)
    logger = logging.getLogger("mednet")
    logger.addHandler(ch)
    yield buf
    logger.removeHandler(ch)


def _assert_exit_0(result):
    assert (
        result.exit_code == 0
    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"


def _check_help(entry_point):
    runner = CliRunner()
    result = runner.invoke(entry_point, ["--help"])
    _assert_exit_0(result)
    assert result.output.startswith("Usage:")


def test_info_help():
    from mednet.scripts.info import info

    _check_help(info)


def test_info():
    from mednet.scripts.info import info

    runner = CliRunner()
    result = runner.invoke(info)
    _assert_exit_0(result)
    assert "platform:" in result.output
    assert "accelerators:" in result.output
    assert "version:" in result.output
    assert "configured databases:" in result.output
    assert "dependencies:" in result.output
    assert "python:" in result.output


def test_config_help():
    from mednet.scripts.config import config

    _check_help(config)


def test_config_list_help():
    from mednet.scripts.config import list_

    _check_help(list_)


def test_config_list():
    from mednet.scripts.config import list_

    runner = CliRunner()
    result = runner.invoke(list_)
    _assert_exit_0(result)
    assert "module: mednet.config.data" in result.output
    assert "module: mednet.config.models" in result.output


def test_config_list_v():
    from mednet.scripts.config import list_

    result = CliRunner().invoke(list_, ["--verbose"])
    _assert_exit_0(result)
    assert "module: mednet.config.data" in result.output
    assert "module: mednet.config.models" in result.output


def test_config_describe_help():
    from mednet.scripts.config import describe

    _check_help(describe)


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_config_describe_montgomery():
    from mednet.scripts.config import describe

    runner = CliRunner()
    result = runner.invoke(describe, ["montgomery"])
    _assert_exit_0(result)
    assert "Montgomery DataModule for TB detection." in result.output


def test_database_help():
    from mednet.scripts.database import database

    _check_help(database)


def test_datamodule_list_help():
    from mednet.scripts.database import list_

    _check_help(list_)


def test_datamodule_list():
    from mednet.scripts.database import list_

    runner = CliRunner()
    result = runner.invoke(list_)
    _assert_exit_0(result)
    assert result.output.startswith("Available databases:")


def test_datamodule_check_help():
    from mednet.scripts.database import check

    _check_help(check)


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_database_check():
    from mednet.scripts.database import check

    runner = CliRunner()
    result = runner.invoke(check, ["--verbose", "--limit=1", "montgomery"])
    _assert_exit_0(result)


def test_main_help():
    from mednet.scripts.cli import cli

    _check_help(cli)


def test_train_help():
    from mednet.scripts.train import train

    _check_help(train)


def _str_counter(substr, s):
    return sum(1 for _ in re.finditer(substr, s, re.MULTILINE))


def test_predict_help():
    from mednet.scripts.predict import predict

    _check_help(predict)


def test_evaluate_help():
    from mednet.scripts.evaluate import evaluate

    _check_help(evaluate)


def test_saliency_generate_help():
    from mednet.scripts.saliency.generate import generate

    _check_help(generate)


def test_saliency_completeness_help():
    from mednet.scripts.saliency.completeness import completeness

    _check_help(completeness)


def test_saliency_view_help():
    from mednet.scripts.saliency.view import view

    _check_help(view)


def test_saliency_evaluate_help():
    from mednet.scripts.saliency.evaluate import evaluate

    _check_help(evaluate)


def test_upload_help():
    from mednet.scripts.upload import upload

    _check_help(upload)


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery(temporary_basedir):
    from mednet.scripts.train import train
    from mednet.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )

    runner = CliRunner()

    with stdout_logging() as buf:
        output_folder = temporary_basedir / "results"
        result = runner.invoke(
            train,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--epochs=1",
                "--batch-size=1",
                f"--output-folder={str(output_folder)}",
            ],
        )
        _assert_exit_0(result)

        # asserts checkpoints are there, or raises FileNotFoundError
        last = _get_checkpoint_from_alias(output_folder, "periodic")
        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
        best = _get_checkpoint_from_alias(output_folder, "best")
        assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)

        assert (
            len(list((output_folder / "logs").glob("events.out.tfevents.*")))
            == 1
        )
        assert (output_folder / "meta.json").exists()

        keywords = {
            r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Applying train/valid loss balancing...$": 1,
            r"^Training for at most 1 epochs.$": 1,
            r"^Uninitialised pasa model - computing z-norm factors from train dataloader.$": 1,
            r"^Writing run metadata at.*$": 1,
            r"^Dataset `train` is already setup. Not re-instantiating it.$": 1,
            r"^Dataset `validation` is already setup. Not re-instantiating it.$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"String '{k}' appeared "
                f"{_str_counter(k, logging_output)} time(s) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
    from mednet.scripts.train import train
    from mednet.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )

    runner = CliRunner()

    output_folder = temporary_basedir / "results" / "pasa_checkpoint"
    result0 = runner.invoke(
        train,
        [
            "pasa",
            "montgomery",
            "-vv",
            "--epochs=1",
            "--batch-size=1",
            f"--output-folder={str(output_folder)}",
        ],
    )
    _assert_exit_0(result0)

    # asserts checkpoints are there, or raises FileNotFoundError
    last = _get_checkpoint_from_alias(output_folder, "periodic")
    assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
    best = _get_checkpoint_from_alias(output_folder, "best")
    assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
    assert (output_folder / "meta.json").exists()

    assert (
        len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1
    )

    with stdout_logging() as buf:
        result = runner.invoke(
            train,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--epochs=2",
                "--batch-size=1",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # asserts checkpoints are there, or raises FileNotFoundError
        last = _get_checkpoint_from_alias(output_folder, "periodic")
        assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION)
        best = _get_checkpoint_from_alias(output_folder, "best")
        assert (output_folder / "meta.json").exists()

        assert (
            len(list((output_folder / "logs").glob("events.out.tfevents.*")))
            == 2
        )

        keywords = {
            r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
            r"^Applying train/valid loss balancing...$": 1,
            r"^Training for at most 2 epochs.$": 1,
            r"^Resuming from epoch 0 \(checkpoint file: .*$": 1,
            r"^Writing run metadata at.*$": 1,
            r"^Dataset `train` is already setup. Not re-instantiating it.$": 1,
            r"^Dataset `validation` is already setup. Not re-instantiating it.$": 1,
            r"^Restoring normalizer from checkpoint.$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"String '{k}' appeared "
                f"{_str_counter(k, logging_output)} time(s) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_predict_pasa_montgomery(temporary_basedir):
    # reuses the checkpoint written by test_train_pasa_montgomery under
    # ``temporary_basedir / "results"``
    from mednet.scripts.predict import predict
    from mednet.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
    )

    runner = CliRunner()

    with stdout_logging() as buf:
        output = temporary_basedir / "predictions.json"
        last = _get_checkpoint_from_alias(
            temporary_basedir / "results",
            "periodic",
        )
        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
        result = runner.invoke(
            predict,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--batch-size=1",
                f"--weight={str(last)}",
                f"--output={str(output)}",
            ],
        )
        _assert_exit_0(result)
        assert output.exists()

        keywords = {
            r"^Loading dataset: * without caching. Trade-off: CPU RAM usage: less | Disk I/O: more$": 3,
            r"^Loading checkpoint from .*$": 1,
            r"^Restoring normalizer from checkpoint.$": 1,
            r"^Running prediction on `train` split...$": 1,
            r"^Running prediction on `validation` split...$": 1,
            r"^Running prediction on `test` split...$": 1,
            r"^Predictions saved to .*$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"String '{k}' appeared "
                f"{_str_counter(k, logging_output)} time(s) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_evaluate_pasa_montgomery(temporary_basedir):
    # consumes the ``predictions.json`` file written by
    # test_predict_pasa_montgomery
    from mednet.scripts.evaluate import evaluate

    runner = CliRunner()

    with stdout_logging() as buf:
        prediction_path = temporary_basedir / "predictions.json"
        evaluation_filename = "evaluation.json"
        evaluation_file = temporary_basedir / evaluation_filename

        result = runner.invoke(
            evaluate,
            [
                "-vv",
                "montgomery",
                f"--predictions={str(prediction_path)}",
                f"--output-folder={str(temporary_basedir)}",
                "--threshold=test",
            ],
        )
        _assert_exit_0(result)

        assert evaluation_file.exists()
        assert evaluation_file.with_suffix(".meta.json").exists()
        assert evaluation_file.with_suffix(".rst").exists()
        assert evaluation_file.with_suffix(".pdf").exists()

        keywords = {
            r"^Setting --threshold=.*$": 1,
            r"^Analyzing split `train`...$": 1,
            r"^Analyzing split `validation`...$": 1,
            r"^Analyzing split `test`...$": 1,
            r"^Saving evaluation results .*$": 2,
            r"^Saving evaluation figures at .*$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"String '{k}' appeared "
                f"{_str_counter(k, logging_output)} time(s) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_experiment(temporary_basedir):
    from mednet.scripts.experiment import experiment

    runner = CliRunner()

    output_folder = temporary_basedir / "experiment"
    num_epochs = 2
    result = runner.invoke(
        experiment,
        [
            "-vv",
            "pasa",
            "montgomery",
            f"--epochs={num_epochs}",
            f"--output-folder={str(output_folder)}",
        ],
    )
    _assert_exit_0(result)

    assert (output_folder / "model" / "meta.json").exists()
    assert (
        output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt"
    ).exists()
    assert (output_folder / "predictions.json").exists()
    assert (output_folder / "predictions.meta.json").exists()

    # Need to glob because we cannot be sure of the checkpoint with lowest
    # validation loss
    assert (
        len(
            list(
                (output_folder / "model").glob(
                    "model-at-lowest-validation-loss-epoch=*.ckpt",
                ),
            ),
        )
        == 1
    )
    assert (output_folder / "trainlog.pdf").exists()
    assert (
        len(
            list(
                (output_folder / "model" / "logs").glob(
                    "events.out.tfevents.*",
                ),
            ),
        )
        == 1
    )
    assert (output_folder / "evaluation.json").exists()
    assert (output_folder / "evaluation.meta.json").exists()
    assert (output_folder / "evaluation.rst").exists()
    assert (output_folder / "evaluation.pdf").exists()
    assert (output_folder / "gradcam" / "saliencies").exists()
    assert (
        len(
            list(
                (output_folder / "gradcam" / "saliencies" / "CXR_png").glob(
                    "MCUCXR_*.npy",
                ),
            ),
        )
        == 138
    )
    assert (output_folder / "gradcam" / "visualizations").exists()
    assert (
        len(
            list(
                (output_folder / "gradcam" / "visualizations" / "CXR_png").glob(
                    "MCUCXR_*.png",
                ),
            ),
        )
        == 58
    )