From 946e86f598605762af174bca0e4d57930e478370 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 27 Jun 2024 17:31:04 +0200
Subject: [PATCH] [tests.classification] Streamline testing

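Evaluation, saliency-map generation and saliency viewing now record
their execution metadata through the shared save_json_with_backup()
helper instead of ad-hoc log messages: the generate and view scripts
each write a dedicated *.meta.json file into their output folder, and
the tests now match the common "Writing run metadata at ..." message.
Two new CLI tests exercise saliency generation and viewing on the
Montgomery database.

The recorded metadata is plain JSON and can be inspected directly.  A
minimal sketch (the "results" path and the printed keys are
illustrative only; execution_metadata() contributes further keys not
shown in this patch):

    import json
    import pathlib

    meta_file = pathlib.Path("results") / "saliency-generation.meta.json"
    meta = json.loads(meta_file.read_text())
    # underscores in the recorded keys are rewritten to hyphens on save
    print(meta["model-name"], meta["saliency-map-algorithm"])
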
---
 .../classification/engine/saliency/viewer.py  |  3 +-
 .../libs/classification/scripts/evaluate.py   |  1 -
 .../scripts/saliency/generate.py              | 25 +++++
 .../classification/scripts/saliency/view.py   | 31 ++++--
 tests/classification/test_cli.py              | 99 ++++++++++++++++++-
 5 files changed, 148 insertions(+), 11 deletions(-)

diff --git a/src/mednet/libs/classification/engine/saliency/viewer.py b/src/mednet/libs/classification/engine/saliency/viewer.py
index 112820a0..26d60613 100644
--- a/src/mednet/libs/classification/engine/saliency/viewer.py
+++ b/src/mednet/libs/classification/engine/saliency/viewer.py
@@ -221,7 +221,8 @@ def run(
 
     for dataset_name, dataset_loader in datamodule.predict_dataloader().items():
         logger.info(
-            f"Generating visualisations for samples at dataset `{dataset_name}`...",
+            f"Generating visualizations for samples (target = {target_label}) "
+            f"at dataset `{dataset_name}`..."
         )
 
         for sample in tqdm(
diff --git a/src/mednet/libs/classification/scripts/evaluate.py b/src/mednet/libs/classification/scripts/evaluate.py
index d0553bad..4e7932e8 100644
--- a/src/mednet/libs/classification/scripts/evaluate.py
+++ b/src/mednet/libs/classification/scripts/evaluate.py
@@ -148,7 +148,6 @@ def evaluate(
     )
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
     evaluation_meta = evaluation_file.with_suffix(".meta.json")
-    logger.info(f"Saving evaluation metadata at `{str(evaluation_meta)}`...")
     save_json_with_backup(evaluation_meta, json_data)
 
     if threshold in predict_data:
diff --git a/src/mednet/libs/classification/scripts/saliency/generate.py b/src/mednet/libs/classification/scripts/saliency/generate.py
index 825c7b3e..2e4c7569 100644
--- a/src/mednet/libs/classification/scripts/saliency/generate.py
+++ b/src/mednet/libs/classification/scripts/saliency/generate.py
@@ -171,12 +171,37 @@ def generate(
 
     from mednet.libs.common.engine.device import DeviceManager
     from mednet.libs.common.scripts.predict import setup_datamodule
+    from mednet.libs.common.scripts.utils import (
+        execution_metadata,
+        save_json_with_backup,
+    )
     from mednet.libs.common.utils.checkpointer import (
         get_checkpoint_to_run_inference,
     )
 
     from ...engine.saliency.generator import run
 
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
+
+    # register metadata (after the output folder is guaranteed to exist,
+    # matching the ordering used in saliency/view.py)
+    json_data: dict[str, typing.Any] = execution_metadata()
+    json_data.update(
+        dict(
+            database_name=datamodule.database_name,
+            database_split=datamodule.split_name,
+            model_name=model.name,
+            output_folder=str(output_folder),
+            device=device,
+            cache_samples=cache_samples,
+            weight=str(weight),
+            parallel=parallel,
+            saliency_map_algorithm=saliency_map_algorithm,
+            target_class=target_class,
+            positive_only=positive_only,
+        ),
+    )
+    json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
+    saliency_meta = output_folder / "saliency-generation.meta.json"
+    save_json_with_backup(saliency_meta, json_data)
 
diff --git a/src/mednet/libs/classification/scripts/saliency/view.py b/src/mednet/libs/classification/scripts/saliency/view.py
index 3ad74497..aed9a81d 100644
--- a/src/mednet/libs/classification/scripts/saliency/view.py
+++ b/src/mednet/libs/classification/scripts/saliency/view.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import pathlib
+import typing
 
 import click
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
@@ -97,19 +98,33 @@ def view(
     **_,
 ) -> None:  # numpydoc ignore=PR01
     """Generate heatmaps for input CXRs based on existing saliency maps."""
-    from ...engine.saliency.viewer import run
-
-    assert (
-        input_folder != output_folder
-    ), "Output folder must not be the same as the input folder."
+    from mednet.libs.common.scripts.utils import (
+        execution_metadata,
+        save_json_with_backup,
+    )
 
-    assert not str(output_folder).startswith(
-        str(input_folder),
-    ), "Output folder must not be a subdirectory of the input folder."
+    from ...engine.saliency.viewer import run
 
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
 
+    # register metadata
+    json_data: dict[str, typing.Any] = execution_metadata()
+    json_data.update(
+        dict(
+            database_name=datamodule.database_name,
+            database_split=datamodule.split_name,
+            model_name=model.name,
+            input_folder=str(input_folder),
+            output_folder=str(output_folder),
+            show_groundtruth=show_groundtruth,
+            threshold=threshold,
+        ),
+    )
+    json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
+    saliency_meta = output_folder / "saliency-view.meta.json"
+    save_json_with_backup(saliency_meta, json_data)
+
     datamodule.drop_incomplete_batch = False
     # datamodule.cache_samples = cache_samples
     # datamodule.parallel = parallel
diff --git a/tests/classification/test_cli.py b/tests/classification/test_cli.py
index 8f408319..54d09503 100644
--- a/tests/classification/test_cli.py
+++ b/tests/classification/test_cli.py
@@ -219,6 +219,7 @@ def test_train_pasa_montgomery(session_tmp_path):
         assert (output_folder / "meta.json").exists()
 
         keywords = {
+            r"^Writing run metadata at .*$": 1,
             r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
             r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
             r"^Applying train/valid loss balancing...$": 1,
@@ -271,6 +272,7 @@ def test_predict_pasa_montgomery(session_tmp_path):
         assert (output_folder / "predictions.json").exists()
 
         keywords = {
+            r"^Writing run metadata at .*$": 1,
             r"^Loading dataset: * without caching. Trade-off: CPU RAM usage: less | Disk I/O: more$": 3,
             r"^Loading checkpoint from .*$": 1,
             r"^Restoring normalizer from checkpoint.$": 1,
@@ -317,7 +319,7 @@ def test_evaluate_pasa_montgomery(session_tmp_path):
         assert (output_folder / "evaluation.pdf").exists()
 
         keywords = {
-            r"^Saving evaluation metadata at .*$": 1,
+            r"^Writing run metadata at .*$": 1,
             r"^Setting --threshold=.*$": 1,
             r"^Computing performance on split .*...$": 3,
             r"^Saving evaluation results at .*$": 1,
@@ -335,6 +337,101 @@ def test_evaluate_pasa_montgomery(session_tmp_path):
             )
 
 
+@pytest.mark.slow
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+def test_saliency_generation_pasa_montgomery(session_tmp_path):
+    from mednet.libs.classification.scripts.saliency.generate import generate
+    from mednet.libs.common.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )
+
+    runner = CliRunner()
+
+    with stdout_logging() as buf:
+        output_folder = session_tmp_path / "classification-standalone"
+        last = _get_checkpoint_from_alias(output_folder, "periodic")
+        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+        result = runner.invoke(
+            generate,
+            [
+                "-vv",
+                "pasa",
+                "montgomery",
+                f"--weight={str(last)}",
+                f"--output-folder={str(output_folder)}",
+            ],
+        )
+        _assert_exit_0(result)
+
+        assert (output_folder / "saliency-generation.meta.json").exists()
+
+        keywords = {
+            r"^Writing run metadata at .*$": 1,
+            r"^Loading dataset:.*$": 3,
+            r"^Generating saliency maps for dataset .*$": 3,
+        }
+        buf.seek(0)
+        logging_output = buf.read()
+
+        for k, v in keywords.items():
+            assert _str_counter(k, logging_output) == v, (
+                f"Count for string '{k}' appeared "
+                f"({_str_counter(k, logging_output)}) "
+                f"instead of the expected {v}:\nOutput:\n{logging_output}"
+            )
+
+        assert len(list(output_folder.rglob("*.npy"))) == 138
+
+
+@pytest.mark.slow
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+def test_saliency_view_pasa_montgomery(session_tmp_path):
+    from mednet.libs.classification.scripts.saliency.view import view
+    from mednet.libs.common.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )
+
+    runner = CliRunner()
+
+    with stdout_logging() as buf:
+        output_folder = session_tmp_path / "classification-standalone"
+        last = _get_checkpoint_from_alias(output_folder, "periodic")
+        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+        result = runner.invoke(
+            view,
+            [
+                "-vv",
+                "pasa",
+                "montgomery",
+                f"--input-folder={str(output_folder)}",
+                f"--output-folder={str(output_folder)}",
+            ],
+        )
+        _assert_exit_0(result)
+
+        assert (output_folder / "saliency-view.meta.json").exists()
+
+        keywords = {
+            r"^Writing run metadata at .*$": 1,
+            r"^Loading dataset:.*$": 3,
+            r"^Generating visualizations for samples \(target = 1\) at dataset .*$": 3,
+        }
+        buf.seek(0)
+        logging_output = buf.read()
+
+        for k, v in keywords.items():
+            assert _str_counter(k, logging_output) == v, (
+                f"Count for string '{k}' appeared "
+                f"({_str_counter(k, logging_output)}) "
+                f"instead of the expected {v}:\nOutput:\n{logging_output}"
+            )
+
+        # there are only 58 samples with target = 1
+        assert len(list(output_folder.rglob("*.png"))) == 58
+
+
 @pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery_from_checkpoint(tmp_path):
-- 
GitLab