diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index b954a0009cf3fbaac39620f1fa7d7c75275aaa7d..82353be3fb2080b9e761dba9d2b52899316b9dd9 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -592,7 +592,8 @@ class ConcatDataModule(lightning.LightningDataModule):
             ] = multiprocessing.get_context("spawn")

         # keep workers hanging around if we have multiple
-        self._dataloader_multiproc["persistent_workers"] = True
+        if value >= 0:
+            self._dataloader_multiproc["persistent_workers"] = True

     @property
     def model_transforms(self) -> list[Transform] | None:
diff --git a/src/ptbench/engine/saliency/generator.py b/src/ptbench/engine/saliency/generator.py
index 5d87167aa6234b7f72387923c95d826d100a126b..c8fa1d4614541e351af67ab341cf71bc63731b5b 100644
--- a/src/ptbench/engine/saliency/generator.py
+++ b/src/ptbench/engine/saliency/generator.py
@@ -101,7 +101,7 @@ def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
-    saliency_map_algorithms: typing.Sequence[SaliencyMapAlgorithm],
+    saliency_map_algorithm: SaliencyMapAlgorithm,
     target_class: typing.Literal["highest", "all"],
     positive_only: bool,
     output_folder: pathlib.Path,
@@ -119,8 +119,8 @@ def run(
         An internal device representation, to be used for training and
         validation. This representation can be converted into a pytorch device
         or a torch lightning accelerator setup.
-    saliency_map_algorithms
-        The algorithms for saliency map estimation to use.
+    saliency_map_algorithm
+        The algorithm to use for saliency map estimation.
     target_class
         (Use only with multi-label models) Which class to target for CAM
         calculation. Can be either set to "all" or "highest". "highest" is
@@ -140,7 +140,7 @@ def run(
     from ...models.pasa import Pasa

     if isinstance(model, Pasa):
-        if "fullgrad" in saliency_map_algorithms:
+        if saliency_map_algorithm == "fullgrad":
             raise ValueError(
                 "Fullgrad saliency map algorithm is not supported for the "
                 "Pasa model."
@@ -160,71 +160,63 @@ def run(
     model = model.to(device)
     model.eval()

-    for algo_type in saliency_map_algorithms:
-        saliency_map_callable = _create_saliency_map_callable(
-            algo_type,
-            model,
-            target_layers,  # type: ignore
-            use_cuda,
+    saliency_map_callable = _create_saliency_map_callable(
+        saliency_map_algorithm,
+        model,
+        target_layers,  # type: ignore
+        use_cuda,
+    )
+
+    for k, v in datamodule.predict_dataloader().items():
+        logger.info(
+            f"Generating saliency maps for dataset `{k}` via `{saliency_map_algorithm}`..."
         )

-        for k, v in datamodule.predict_dataloader().items():
-            logger.info(
-                f"Generating saliency maps for dataset `{k}` via `{algo_type}`..."
+        for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None):
+            name = sample[1]["name"][0]
+            label = sample[1]["label"].item()
+            image = sample[0].to(
+                device=device, non_blocking=torch.cuda.is_available()
             )

-            for sample in tqdm.tqdm(
-                v, desc="samples", leave=False, disable=None
-            ):
-                name = sample[1]["name"][0]
-                label = sample[1]["label"].item()
-                image = sample[0].to(
-                    device=device, non_blocking=torch.cuda.is_available()
-                )
-
-                # in binary classification systems, negative labels may be skipped
-                if positive_only and (model.num_classes == 1) and (label == 0):
-                    continue
-
-                # chooses target outputs to generate saliency maps for
-                if model.num_classes > 1:
-                    if target_class == "all":
-                        # just blindly generate saliency maps for all outputs
-                        # - make one directory for every target output and lay
-                        # images there like in the original dataset.
-                        for output_num in range(model.num_classes):
-                            use_folder = (
-                                output_folder / algo_type / str(output_num)
-                            )
-                            saliency_map = saliency_map_callable(
-                                input_tensor=image,
-                                targets=[ClassifierOutputTarget(output_num)],  # type: ignore
-                            )
-                            _save_saliency_map(use_folder, name, saliency_map)  # type: ignore
-
-                    else:
-                        # pytorch-grad-cam will figure out the output with the
-                        # highest value and produce a saliency map for it - we
-                        # will save it to disk.
-                        use_folder = (
-                            output_folder / algo_type / "highest-output"
-                        )
+            # in binary classification systems, negative labels may be skipped
+            if positive_only and (model.num_classes == 1) and (label == 0):
+                continue
+
+            # chooses target outputs to generate saliency maps for
+            if model.num_classes > 1:
+                if target_class == "all":
+                    # just blindly generate saliency maps for all outputs
+                    # - make one directory for every target output and lay
+                    # images there like in the original dataset.
+                    for output_num in range(model.num_classes):
+                        use_folder = output_folder / str(output_num)
                         saliency_map = saliency_map_callable(
                             input_tensor=image,
-                            # setting `targets=None` will set target to the
-                            # maximum output index using
-                            # ClassifierOutputTarget(max_output_index)
-                            targets=None,  # type: ignore
+                            targets=[ClassifierOutputTarget(output_num)],  # type: ignore
                         )
                         _save_saliency_map(use_folder, name, saliency_map)  # type: ignore
+
                 else:
-                    # binary classification model with a single output - just
-                    # lay all cams uniformily like the original dataset
-                    use_folder = output_folder / algo_type
+                    # pytorch-grad-cam will figure out the output with the
+                    # highest value and produce a saliency map for it - we
+                    # will save it to disk.
+                    use_folder = output_folder / "highest-output"
                     saliency_map = saliency_map_callable(
                         input_tensor=image,
-                        targets=[
-                            ClassifierOutputTarget(0),  # type: ignore
-                        ],
+                        # setting `targets=None` will set target to the
+                        # maximum output index using
+                        # ClassifierOutputTarget(max_output_index)
+                        targets=None,  # type: ignore
                     )
                     _save_saliency_map(use_folder, name, saliency_map)  # type: ignore
+            else:
+                # binary classification model with a single output - just
+                # lay all cams uniformily like the original dataset
+                saliency_map = saliency_map_callable(
+                    input_tensor=image,
+                    targets=[
+                        ClassifierOutputTarget(0),  # type: ignore
+                    ],
+                )
+                _save_saliency_map(output_folder, name, saliency_map)  # type: ignore
diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py
index 62c7512be56a0b1a56f01195ddf5bc8c8e81aeea..25ebd01fad4ed4d123d87159b921eaaa9ed15e5a 100644
--- a/src/ptbench/engine/saliency/interpretability.py
+++ b/src/ptbench/engine/saliency/interpretability.py
@@ -269,7 +269,7 @@ def _compute_proportional_energy(
 def _process_sample(
     gt_bboxes: BoundingBoxes,
     saliency_map: numpy.typing.NDArray[numpy.double],
-) -> tuple[float, float, float, float, BoundingBox]:
+) -> tuple[float, float, float, float, tuple[int, int, int, int]]:
     """Calculates the metrics for a single sample.

     Parameters
@@ -310,7 +310,12 @@ def _process_sample(
         ioda,
         _compute_proportional_energy(saliency_map, binary_mask),
         _compute_avg_saliency_focus(saliency_map, binary_mask),
-        detected_box,
+        (
+            detected_box.xmin,
+            detected_box.ymin,
+            detected_box.width,
+            detected_box.height,
+        ),
     )


@@ -396,7 +401,7 @@ def run(
                     name,
                     label,
                     *_process_sample(
-                        bboxes,
+                        bboxes[0],
                         numpy.load(
                             input_folder
                             / pathlib.Path(name).with_suffix(".npy")
diff --git a/src/ptbench/scripts/experiment.py b/src/ptbench/scripts/experiment.py
index 1d520fe3e963e7bd438ed856b4b42ae0abd7079a..f75569513efaada033140c1f78c58cf79a19a72a 100644
--- a/src/ptbench/scripts/experiment.py
+++ b/src/ptbench/scripts/experiment.py
@@ -111,11 +111,6 @@ def experiment(

     logger.info("Started predicting")

-    from ..utils.checkpointer import get_checkpoint_to_run_inference
-
-    model_file = get_checkpoint_to_run_inference(train_output_folder)
-    logger.info(f"Found `{str(model_file)}`. Continuing...")
-
     from .predict import predict

     predictions_output = output_folder / "predictions.json"
@@ -126,7 +121,7 @@ def experiment(
         model=model,
         datamodule=datamodule,
         device=device,
-        weight=model_file,
+        weight=train_output_folder,
         batch_size=batch_size,
         parallel=parallel,
     )
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 91b66d725c83ed9802d7ff9f03c2e796d670e546..6b77f3d3d9d580df59348a45ca9841b19b9ba0a6 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -23,13 +23,13 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

    .. code:: sh

-       ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
+       ptbench predict -vv pasa montgomery --weight=path/to/model-at-lowest-validation-loss.ckpt --output=path/to/predictions.json

 2. Enables multi-processing data loading with 6 processes:

    .. code:: sh

-       ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
+       ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model-at-lowest-validation-loss.ckpt --output=path/to/predictions.json

 """,
 )
@@ -88,10 +88,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     "--weight",
     "-w",
     help="""Path or URL to pretrained model file (`.ckpt` extension),
-    corresponding to the architecture set with `--model`.""",
+    corresponding to the architecture set with `--model`. Optionally, you may
+    also pass a directory containing the result of a training session, in which
+    case either the best (lowest validation) or latest model will be loaded.""",
     required=True,
     cls=ResourceOption,
-    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
+    type=click.Path(
+        exists=True,
+        file_okay=True,
+        dir_okay=True,
+        readable=True,
+        path_type=pathlib.Path,
+    ),
 )
 @click.option(
     "--parallel",
@@ -125,6 +133,7 @@ def predict(

     from ..engine.device import DeviceManager
     from ..engine.predictor import run
+    from ..utils.checkpointer import get_checkpoint_to_run_inference

     datamodule.set_chunk_size(batch_size, 1)
     datamodule.parallel = parallel
@@ -133,6 +142,9 @@ def predict(
     datamodule.prepare_data()
     datamodule.setup(stage="predict")

+    if weight.is_dir():
+        weight = get_checkpoint_to_run_inference(weight)
+
     logger.info(f"Loading checkpoint from `{weight}`...")
     model = type(model).load_from_checkpoint(weight, strict=False)

diff --git a/src/ptbench/scripts/saliency/completeness.py b/src/ptbench/scripts/saliency/completeness.py
index 29c7599ea5660bd1bb4018d3b9aa4c08298d7a28..eb305b6847aed6618f88ebdaba8b032b8e6f133c 100644
--- a/src/ptbench/scripts/saliency/completeness.py
+++ b/src/ptbench/scripts/saliency/completeness.py
@@ -25,7 +25,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

    .. code:: sh

-       ptbench saliency completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/completeness-scores/
+       ptbench saliency completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-json=path/to/completeness-scores.json

 """,
 )
@@ -49,18 +49,17 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--output-folder",
+    "--output-json",
     "-o",
-    help="Path where to store saliency maps (created if does not exist)",
+    help="""Path where to store the output JSON file containing all
+    measures.""",
     required=True,
     type=click.Path(
-        exists=False,
-        file_okay=False,
-        dir_okay=True,
-        writable=True,
+        file_okay=True,
+        dir_okay=False,
         path_type=pathlib.Path,
     ),
-    default="saliency-maps",
+    default="saliency-interpretability.json",
     cls=ResourceOption,
 )
 @click.option(
@@ -85,10 +84,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     "--weight",
     "-w",
     help="""Path or URL to pretrained model file (`.ckpt` extension),
-    corresponding to the architecture set with `--model`.""",
+    corresponding to the architecture set with `--model`. Optionally, you may
+    also pass a directory containing the result of a training session, in which
+    case either the best (lowest validation) or latest model will be loaded.""",
     required=True,
     cls=ResourceOption,
-    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
+    type=click.Path(
+        exists=True,
+        file_okay=True,
+        dir_okay=True,
+        readable=True,
+        path_type=pathlib.Path,
+    ),
 )
 @click.option(
     "--parallel",
@@ -108,13 +115,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--saliency-map-algorithm",
     "-s",
-    help="""Saliency map algorithm(s) to be used. Can be called multiple times
-    with different techniques.""",
+    help="""Saliency map algorithm to be used.""",
     type=click.Choice(
         typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
     ),
-    multiple=True,
-    default=["gradcam"],
+    default="gradcam",
     show_default=True,
     cls=ResourceOption,
 )
@@ -157,7 +162,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 def completeness(
     model,
     datamodule,
-    output_folder,
+    output_json,
     device,
     cache_samples,
     weight,
@@ -171,7 +176,7 @@ def completeness(
     """Evaluates saliency map algorithm completeness using RemOve And Debias
     (ROAD).

-    For each selected saliency map algorithm, evaluates the completeness of
+    For the selected saliency map algorithm, evaluates the completeness of
     explanations using the RemOve And Debias (ROAD) algorithm. The ROAD
     algorithm was first described at [ROAD-2022]_. It estimates explainability
     (in the completeness sense) of saliency mapping algorithms by substituting
@@ -185,12 +190,9 @@ def completeness(

     This program outputs a JSON file containing the ROAD evaluations (using
     most-relevant-first, or MoRF, and least-relevant-first, or LeRF for each
-    sample in the datamodule, and per saliency-mapping algorithm. Each
-    saliency-mapping algorithm yields a single JSON file with the target
-    algorithm name on the ``output-folder``. Values for MoRF and LeRF represent
-    averages by removing 20, 40, 60 and 80% of most or least relevant pixels
-    respectively from the image, and averaging results for all these
-    percentiles.
+    sample in the datamodule. Values for MoRF and LeRF represent averages by
+    removing 20, 40, 60 and 80% of most or least relevant pixels respectively
+    from the image, and averaging results for all these percentiles.

     .. note::

@@ -201,9 +203,7 @@ def completeness(

     from ...engine.device import DeviceManager
     from ...engine.saliency.completeness import run
-
-    logger.info(f"Output folder: {output_folder}")
-    output_folder.mkdir(parents=True, exist_ok=True)
+    from ...utils.checkpointer import get_checkpoint_to_run_inference

     if device in ("cuda", "mps") and (parallel == 0 or parallel > 1):
         raise RuntimeError(
@@ -226,27 +226,29 @@ def completeness(
     datamodule.prepare_data()
     datamodule.setup(stage="predict")

-    logger.info(f"Loading checkpoint from `{weight}`...")
-    model = model.load_from_checkpoint(weight, strict=False)
-
-    for algo in saliency_map_algorithm:
-        logger.info(
-            f"Evaluating RemOve And Debias (ROAD) average scores for "
-            f"algorithm `{algo}` with percentiles "
-            f"`{', '.join([str(k) for k in percentile])}`..."
-        )
-        results = run(
-            model=model,
-            datamodule=datamodule,
-            device_manager=device_manager,
-            saliency_map_algorithm=algo,
-            target_class=target_class,
-            positive_only=positive_only,
-            percentiles=percentile,
-            parallel=parallel,
-        )
+    if weight.is_dir():
+        weight = get_checkpoint_to_run_inference(weight)

-        output_json = output_folder / (algo + ".json")
-        with output_json.open("w") as f:
-            logger.info(f"Saving output file to `{str(output_json)}`...")
-            json.dump(results, f, indent=2)
+    logger.info(f"Loading checkpoint from `{weight}`...")
+    model = type(model).load_from_checkpoint(weight, strict=False)
+
+    logger.info(
+        f"Evaluating RemOve And Debias (ROAD) average scores for "
+        f"algorithm `{saliency_map_algorithm}` with percentiles "
+        f"`{', '.join([str(k) for k in percentile])}`..."
+    )
+    results = run(
+        model=model,
+        datamodule=datamodule,
+        device_manager=device_manager,
+        saliency_map_algorithm=saliency_map_algorithm,
+        target_class=target_class,
+        positive_only=positive_only,
+        percentiles=percentile,
+        parallel=parallel,
+    )
+
+    output_json.parent.mkdir(parents=True, exist_ok=True)
+    with output_json.open("w") as f:
+        logger.info(f"Saving output file to `{str(output_json)}`...")
+        json.dump(results, f, indent=2)
diff --git a/src/ptbench/scripts/saliency/generate.py b/src/ptbench/scripts/saliency/generate.py
index 6ccb9f5e4aeb15f25df761dbe744d5222ff74fb3..b0327728c53aec36cfa08770a2bded3779e8b8ce 100644
--- a/src/ptbench/scripts/saliency/generate.py
+++ b/src/ptbench/scripts/saliency/generate.py
@@ -87,10 +87,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     "--weight",
     "-w",
     help="""Path or URL to pretrained model file (`.ckpt` extension),
-    corresponding to the architecture set with `--model`.""",
+    corresponding to the architecture set with `--model`. Optionally, you may
+    also pass a directory containing the result of a training session, in which
+    case either the best (lowest validation) or latest model will be loaded.""",
     required=True,
     cls=ResourceOption,
-    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
+    type=click.Path(
+        exists=True,
+        file_okay=True,
+        dir_okay=True,
+        readable=True,
+        path_type=pathlib.Path,
+    ),
 )
 @click.option(
     "--parallel",
@@ -108,13 +116,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--saliency-map-algorithm",
     "-s",
-    help="""Saliency map algorithm(s) to be used. Can be called multiple times
-    with different techniques.""",
+    help="""Saliency map algorithm to be used.""",
     type=click.Choice(
         typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
     ),
-    multiple=True,
-    default=["gradcam"],
+    default="gradcam",
     show_default=True,
     cls=ResourceOption,
 )
@@ -165,6 +171,7 @@ def generate(

     from ...engine.device import DeviceManager
     from ...engine.saliency.generator import run
+    from ...utils.checkpointer import get_checkpoint_to_run_inference

     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
@@ -181,14 +188,17 @@ def generate(
     datamodule.prepare_data()
     datamodule.setup(stage="predict")

+    if weight.is_dir():
+        weight = get_checkpoint_to_run_inference(weight)
+
     logger.info(f"Loading checkpoint from `{weight}`...")
-    model = model.load_from_checkpoint(weight, strict=False)
+    model = type(model).load_from_checkpoint(weight, strict=False)

     run(
         model=model,
         datamodule=datamodule,
         device_manager=device_manager,
-        saliency_map_algorithms=saliency_map_algorithm,
+        saliency_map_algorithm=saliency_map_algorithm,
         target_class=target_class,
         positive_only=positive_only,
         output_folder=output_folder,
diff --git a/src/ptbench/scripts/saliency/interpretability.py b/src/ptbench/scripts/saliency/interpretability.py
index 155f81db4559fcba6bc420a7f57d21efe61f37d6..978bda7a841d0b25f6d552dd004aa367219f1512 100644
--- a/src/ptbench/scripts/saliency/interpretability.py
+++ b/src/ptbench/scripts/saliency/interpretability.py
@@ -23,7 +23,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

    .. code:: sh

-       ptbench saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-json=parent_folder/gradcam/tbx11k-v1-interp.json
+       ptbench saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json

 """,
 )
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
index 988e611e4b48420ef02ba327533b5c1829452216..543f7c7008e2aef6403babdc55cd4f391d36770b 100644
--- a/src/ptbench/utils/checkpointer.py
+++ b/src/ptbench/utils/checkpointer.py
@@ -139,7 +139,7 @@ def get_checkpoint_to_run_inference(
     """

     try:
-        _get_checkpoint_from_alias(path, "best")
+        return _get_checkpoint_from_alias(path, "best")
     except FileNotFoundError:
         logger.error(
             "Did not find lowest-validation-loss model to run inference "
diff --git a/tests/test_cam_utils.py b/tests/test_cam_utils.py
index 2d3273d55c08aa9e5d7ba5da6293abae959553d6..54f81a9ef440ee27e16f4f0ab2ecdf2272f10b46 100644
--- a/tests/test_cam_utils.py
+++ b/tests/test_cam_utils.py
@@ -3,19 +3,17 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Tests for the cam_utils script."""

-import cv2
 import numpy as np
-import pandas as pd
 import pytest

-from ptbench.utils.cam_utils import (
-    _calculate_stats_over_dataset,
-    calculate_metrics_avg_for_every_class,
-    draw_boxes_on_image,
-    draw_largest_component_bbox_on_image,
-    show_cam_on_image,
-    visualize_road_scores,
-)
+# from ptbench.utils.cam_utils import (
+#     _calculate_stats_over_dataset,
+#     calculate_metrics_avg_for_every_class,
+#     draw_boxes_on_image,
+#     draw_largest_component_bbox_on_image,
+#     show_cam_on_image,
+#     visualize_road_scores,
+# )


 def test_calculate_stats_over_dataset(datadir):