From 946e86f598605762af174bca0e4d57930e478370 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 27 Jun 2024 17:31:04 +0200
Subject: [PATCH] [tests.classification] Streamline testing

---
 .../classification/engine/saliency/viewer.py |  3 +-
 .../libs/classification/scripts/evaluate.py  |  1 -
 .../scripts/saliency/generate.py             | 25 +++++
 .../classification/scripts/saliency/view.py  | 31 ++++--
 tests/classification/test_cli.py             | 99 ++++++++++++++++++-
 5 files changed, 148 insertions(+), 11 deletions(-)

diff --git a/src/mednet/libs/classification/engine/saliency/viewer.py b/src/mednet/libs/classification/engine/saliency/viewer.py
index 112820a0..26d60613 100644
--- a/src/mednet/libs/classification/engine/saliency/viewer.py
+++ b/src/mednet/libs/classification/engine/saliency/viewer.py
@@ -221,7 +221,8 @@ def run(
 
     for dataset_name, dataset_loader in datamodule.predict_dataloader().items():
         logger.info(
-            f"Generating visualisations for samples at dataset `{dataset_name}`...",
+            f"Generating visualizations for samples (target = {target_label}) "
+            f"at dataset `{dataset_name}`..."
         )
 
         for sample in tqdm(
diff --git a/src/mednet/libs/classification/scripts/evaluate.py b/src/mednet/libs/classification/scripts/evaluate.py
index d0553bad..4e7932e8 100644
--- a/src/mednet/libs/classification/scripts/evaluate.py
+++ b/src/mednet/libs/classification/scripts/evaluate.py
@@ -148,7 +148,6 @@ def evaluate(
     )
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
     evaluation_meta = evaluation_file.with_suffix(".meta.json")
-    logger.info(f"Saving evaluation metadata at `{str(evaluation_meta)}`...")
     save_json_with_backup(evaluation_meta, json_data)
 
     if threshold in predict_data:
diff --git a/src/mednet/libs/classification/scripts/saliency/generate.py b/src/mednet/libs/classification/scripts/saliency/generate.py
index 825c7b3e..2e4c7569 100644
--- a/src/mednet/libs/classification/scripts/saliency/generate.py
+++ b/src/mednet/libs/classification/scripts/saliency/generate.py
@@ -171,12 +171,37 @@ def generate(
 
     from mednet.libs.common.engine.device import DeviceManager
     from mednet.libs.common.scripts.predict import setup_datamodule
+    from mednet.libs.common.scripts.utils import (
+        execution_metadata,
+        save_json_with_backup,
+    )
     from mednet.libs.common.utils.checkpointer import (
         get_checkpoint_to_run_inference,
     )
 
     from ...engine.saliency.generator import run
 
+    # register metadata
+    json_data: dict[str, typing.Any] = execution_metadata()
+    json_data.update(
+        dict(
+            database_name=datamodule.database_name,
+            database_split=datamodule.split_name,
+            model_name=model.name,
+            output_folder=str(output_folder),
+            device=device,
+            cache_samples=cache_samples,
+            weight=str(weight),
+            parallel=parallel,
+            saliency_map_algorithm=saliency_map_algorithm,
+            target_class=target_class,
+            positive_only=positive_only,
+        ),
+    )
+    json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
+    saliency_meta = output_folder / "saliency-generation.meta.json"
+    save_json_with_backup(saliency_meta, json_data)
+
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
 
diff --git a/src/mednet/libs/classification/scripts/saliency/view.py b/src/mednet/libs/classification/scripts/saliency/view.py
index 3ad74497..aed9a81d 100644
--- a/src/mednet/libs/classification/scripts/saliency/view.py
+++ b/src/mednet/libs/classification/scripts/saliency/view.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import pathlib
+import typing
 
 import click
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
@@ -97,19 +98,33 @@ def view(
     **_,
 ) -> None:  # numpydoc ignore=PR01
     """Generate heatmaps for input CXRs based on existing saliency maps."""
-    from ...engine.saliency.viewer import run
-
-    assert (
-        input_folder != output_folder
-    ), "Output folder must not be the same as the input folder."
+    from mednet.libs.common.scripts.utils import (
+        execution_metadata,
+        save_json_with_backup,
+    )
 
-    assert not str(output_folder).startswith(
-        str(input_folder),
-    ), "Output folder must not be a subdirectory of the input folder."
+    from ...engine.saliency.viewer import run
 
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
 
+    # register metadata
+    json_data: dict[str, typing.Any] = execution_metadata()
+    json_data.update(
+        dict(
+            database_name=datamodule.database_name,
+            database_split=datamodule.split_name,
+            model_name=model.name,
+            input_folder=str(input_folder),
+            output_folder=str(output_folder),
+            show_groundtruth=show_groundtruth,
+            threshold=threshold,
+        ),
+    )
+    json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
+    saliency_meta = output_folder / "saliency-view.meta.json"
+    save_json_with_backup(saliency_meta, json_data)
+
     datamodule.drop_incomplete_batch = False
     # datamodule.cache_samples = cache_samples
     # datamodule.parallel = parallel
diff --git a/tests/classification/test_cli.py b/tests/classification/test_cli.py
index 8f408319..54d09503 100644
--- a/tests/classification/test_cli.py
+++ b/tests/classification/test_cli.py
@@ -219,6 +219,7 @@ def test_train_pasa_montgomery(session_tmp_path):
         assert (output_folder / "meta.json").exists()
 
         keywords = {
+            r"^Writing run metadata at .*$": 1,
             r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
             r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
             r"^Applying train/valid loss balancing...$": 1,
@@ -271,6 +272,7 @@ def test_predict_pasa_montgomery(session_tmp_path):
         assert (output_folder / "predictions.json").exists()
 
         keywords = {
+            r"^Writing run metadata at .*$": 1,
             r"^Loading dataset: * without caching. Trade-off: CPU RAM usage: less | Disk I/O: more$": 3,
             r"^Loading checkpoint from .*$": 1,
             r"^Restoring normalizer from checkpoint.$": 1,
@@ -317,7 +319,7 @@ def test_evaluate_pasa_montgomery(session_tmp_path):
         assert (output_folder / "evaluation.pdf").exists()
 
         keywords = {
-            r"^Saving evaluation metadata at .*$": 1,
+            r"^Writing run metadata at .*$": 1,
             r"^Setting --threshold=.*$": 1,
             r"^Computing performance on split .*...$": 3,
             r"^Saving evaluation results at .*$": 1,
@@ -335,6 +337,101 @@ def test_evaluate_pasa_montgomery(session_tmp_path):
             )
 
 
+@pytest.mark.slow
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+def test_saliency_generation_pasa_montgomery(session_tmp_path):
+    from mednet.libs.classification.scripts.saliency.generate import generate
+    from mednet.libs.common.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )
+
+    runner = CliRunner()
+
+    with stdout_logging() as buf:
+        output_folder = session_tmp_path / "classification-standalone"
+        last = _get_checkpoint_from_alias(output_folder, "periodic")
+        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+        result = runner.invoke(
+            generate,
+            [
+                "-vv",
+                "pasa",
+                "montgomery",
+                f"--weight={str(last)}",
+                f"--output-folder={str(output_folder)}",
+            ],
+        )
+        _assert_exit_0(result)
+
+        assert (output_folder / "saliency-generation.meta.json").exists()
+
+        keywords = {
+            r"^Writing run metadata at .*$": 1,
+            r"^Loading dataset:.*$": 3,
+            r"^Generating saliency maps for dataset .*$": 3,
+        }
+        buf.seek(0)
+        logging_output = buf.read()
+
+        for k, v in keywords.items():
+            assert _str_counter(k, logging_output) == v, (
+                f"Count for string '{k}' appeared "
+                f"({_str_counter(k, logging_output)}) "
+                f"instead of the expected {v}:\nOutput:\n{logging_output}"
+            )
+
+        assert len(list(output_folder.rglob("**/*.npy"))) == 138
+
+
+@pytest.mark.slow
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+def test_saliency_view_pasa_montgomery(session_tmp_path):
+    from mednet.libs.classification.scripts.saliency.view import view
+    from mednet.libs.common.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )
+
+    runner = CliRunner()
+
+    with stdout_logging() as buf:
+        output_folder = session_tmp_path / "classification-standalone"
+        last = _get_checkpoint_from_alias(output_folder, "periodic")
+        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+        result = runner.invoke(
+            view,
+            [
+                "-vv",
+                "pasa",
+                "montgomery",
+                f"--input-folder={str(output_folder)}",
+                f"--output-folder={str(output_folder)}",
+            ],
+        )
+        _assert_exit_0(result)
+
+        assert (output_folder / "saliency-view.meta.json").exists()
+
+        keywords = {
+            r"^Writing run metadata at .*$": 1,
+            r"^Loading dataset:.*$": 3,
+            r"^Generating visualizations for samples \(target = 1\) at dataset .*$": 3,
+        }
+        buf.seek(0)
+        logging_output = buf.read()
+
+        for k, v in keywords.items():
+            assert _str_counter(k, logging_output) == v, (
+                f"Count for string '{k}' appeared "
+                f"({_str_counter(k, logging_output)}) "
+                f"instead of the expected {v}:\nOutput:\n{logging_output}"
+            )
+
+        # there are only 58 samples with target = 1
+        assert len(list(output_folder.rglob("**/*.png"))) == 58
+
+
 @pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery_from_checkpoint(tmp_path):
-- 
GitLab