diff --git a/src/ptbench/engine/saliency/completeness.py b/src/ptbench/engine/saliency/completeness.py
index a75d57b2953c6cb283f96116e4e7d9e563fdce08..f26e577f355216e18513258860aa5691caff0773 100644
--- a/src/ptbench/engine/saliency/completeness.py
+++ b/src/ptbench/engine/saliency/completeness.py
@@ -2,7 +2,9 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import functools
 import logging
+import multiprocessing
 import typing
 
 import lightning.pytorch
@@ -16,6 +18,7 @@ from pytorch_grad_cam.metrics.road import (
 )
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 
+from ...data.typing import Sample
 from ...models.typing import SaliencyMapAlgorithm
 from ..device import DeviceManager
 
@@ -110,6 +113,68 @@ def _calculate_road_scores(
     )
 
 
+def _process_sample(
+    sample: Sample,
+    model: lightning.pytorch.LightningModule,
+    device: torch.device,
+    saliency_map_callable: typing.Callable,
+    target_class: typing.Literal["highest", "all"],
+    positive_only: bool,
+    percentiles: typing.Sequence[int],
+) -> list[list[typing.Any]]:
+    """Helper for :py:func:`run`, usable in multiprocessing contexts.
+    Returns one result row per model output evaluated for the sample."""
+
+    name: str = sample[1]["name"][0]
+    label: int = int(sample[1]["label"].item())
+    image = sample[0].to(device=device, non_blocking=torch.cuda.is_available())
+
+    # in binary classification systems, negative labels may be skipped
+    if positive_only and (model.num_classes == 1) and (label == 0):
+        return [[name, label]]
+
+    # chooses target outputs to generate saliency maps for
+    if model.num_classes > 1:  # type: ignore
+        if target_class == "all":
+            # test all outputs, accumulating one result row per output
+            rows = []
+            for output_num in range(model.num_classes):  # type: ignore
+                results = _calculate_road_scores(
+                    model,
+                    image,
+                    output_num,
+                    saliency_map_callable,
+                    percentiles,
+                )
+                rows.append([name, label, output_num, *results])
+            return rows
+
+        else:
+            # we will figure out the output with the highest value and
+            # evaluate the saliency mapping technique over it.
+            outputs = saliency_map_callable.activations_and_grads(image)  # type: ignore
+            output_nums = np.argmax(outputs.cpu().data.numpy(), axis=-1)
+            assert len(output_nums) == 1
+            results = _calculate_road_scores(
+                model,
+                image,
+                output_nums[0],
+                saliency_map_callable,
+                percentiles,
+            )
+            return [[name, label, output_nums[0], *results]]
+
+    # default route for binary classification
+    results = _calculate_road_scores(
+        model,
+        image,
+        0,
+        saliency_map_callable,
+        percentiles,
+    )
+    return [[name, label, 0, *results]]
+
+
 def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
@@ -118,6 +181,7 @@
     target_class: typing.Literal["highest", "all"],
     positive_only: bool,
     percentiles: typing.Sequence[int],
+    parallel: int,
 ) -> dict[str, list[typing.Any]]:
     """Evaluates ROAD scores for all samples in a datamodule.
 
@@ -162,6 +226,11 @@ def run(
         A sequence of percentiles (percent x100) integer values indicating the
         proportion of pixels to perturb in the original image to calculate both
         MoRF and LeRF scores.
+    parallel
+        Use multiprocessing for data processing: if set to -1, disables
+        multiprocessing. Set to 0 to enable as many data processing
+        instances as there are processing cores available on the system.
+        Set to >= 1 to enable that many multiprocessing instances.
 
     Returns
     -------
@@ -201,6 +270,17 @@ def run(
         raise TypeError(f"Model of type `{type(model)}` is not yet supported.")
 
     use_cuda = device_manager.device_type == "cuda"
+    if device_manager.device_type in ("cuda", "mps") and (
+        parallel == 0 or parallel > 1
+    ):
+        raise RuntimeError(
+            f"The number of multiprocessing instances is set to {parallel} "
+            f"and you asked to use a GPU (device = "
+            f"`{device_manager.device_type}`). The current implementation "
+            f"can only handle a single GPU. Either disable GPU utilisation, "
+            f"set the number of multiprocessing instances to one, or "
+            "disable multiprocessing entirely (i.e. set it to -1)."
+        )
 
     # prepares model for evaluation, cast to target device
     device = device_manager.torch_device()
@@ -216,59 +296,38 @@ def run(
 
     retval: dict[str, list[typing.Any]] = {}
 
+    # our worker function
+    _process = functools.partial(
+        _process_sample,
+        model=model,
+        device=device,
+        saliency_map_callable=saliency_map_callable,
+        target_class=target_class,
+        positive_only=positive_only,
+        percentiles=percentiles,
+    )
+
     for k, v in datamodule.predict_dataloader().items():
-        logger.info(f"Computing ROAD scores for dataset `{k}`...")
         retval[k] = []
-        for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None):
-            name = sample[1]["name"][0]
-            label = int(sample[1]["label"].item())
-            image = sample[0].to(
-                device=device, non_blocking=torch.cuda.is_available()
+        if parallel < 0:
+            logger.info(
+                f"Computing ROAD scores for dataset `{k}` in the current "
+                f"process context..."
             )
-
-            # in binary classification systems, negative labels may be skipped
-            if positive_only and (model.num_classes == 1) and (label == 0):
-                retval[k].append([name, label])
-                continue
-
-            # chooses target outputs to generate saliency maps for
-            if model.num_classes > 1:
-                if target_class == "all":
-                    # test all outputs
-                    for output_num in range(model.num_classes):
-                        results = _calculate_road_scores(
-                            model,
-                            image,
-                            output_num,
-                            saliency_map_callable,
-                            percentiles,
-                        )
-                        retval[k].append([name, label, output_num, *results])
-
-                else:
-                    # we will figure out the output with the highest value and
-                    # evaluate the saliency mapping technique over it.
-                    outputs = saliency_map_callable.activations_and_grads(image)
-                    output_nums = np.argmax(outputs.cpu().data.numpy(), axis=-1)
-                    assert len(output_nums) == 1
-                    results = _calculate_road_scores(
-                        model,
-                        image,
-                        output_nums[0],
-                        saliency_map_callable,
-                        percentiles,
-                    )
-                    retval[k].append([name, label, output_nums[0], *results])
-
-            else:
-                results = _calculate_road_scores(
-                    model,
-                    image,
-                    0,
-                    saliency_map_callable,
-                    percentiles,
-                )
-                retval[k].append([name, label, 0, *results])
+            for sample in tqdm.tqdm(
+                v, desc="samples", leave=False, disable=None
+            ):
+                retval[k].extend(_process(sample))
+
+        else:
+            instances = parallel or multiprocessing.cpu_count()
+            logger.info(
+                f"Computing ROAD scores for dataset `{k}` using {instances} "
+                f"processes..."
+            )
+            with multiprocessing.Pool(instances) as p:
+                for rows in tqdm.tqdm(p.imap(_process, v), total=len(v)):
+                    retval[k].extend(rows)
 
     return retval
diff --git a/src/ptbench/scripts/saliency_completeness.py b/src/ptbench/scripts/saliency_completeness.py
index 49a3d3a91bb9cf6dec635e75d906470676627d0c..f17ff648559ff4340fe2aa3179692ac2d6360fd7 100644
--- a/src/ptbench/scripts/saliency_completeness.py
+++ b/src/ptbench/scripts/saliency_completeness.py
@@ -93,10 +93,12 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--parallel",
     "-P",
-    help="""Use multiprocessing for data loading: if set to -1 (default),
-    disables multiprocessing data loading. Set to 0 to enable as many data
-    loading instances as processing cores as available in the system. Set to
-    >= 1 to enable that many multiprocessing instances for data loading.""",
+    help="""Use multiprocessing for data processing: if set to -1 (default),
+    disables multiprocessing. Set to 0 to enable as many data processing
+    instances as there are processing cores available on the system. Set to
+    >= 1 to enable that many multiprocessing instances. Note that if you
+    activate this option, you must use --device=cpu, as concurrent GPU
+    usage is not supported.""",
     type=click.IntRange(min=-1),
     show_default=True,
     required=True,
@@ -203,6 +205,15 @@ def saliency_completeness(
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
 
+    if device in ("cuda", "mps") and (parallel == 0 or parallel > 1):
+        raise RuntimeError(
+            f"The number of multiprocessing instances is set to {parallel} "
+            f"and you asked to use a GPU (device = `{device}`). The current "
+            f"implementation can only handle a single GPU. Either disable "
+            f"GPU utilisation, set the number of multiprocessing instances "
+            "to one, or disable multiprocessing entirely (i.e. set it to -1)."
+        )
+
     device_manager = DeviceManager(device)
 
     # batch_size must be == 1 for now (underlying code is NOT prepared to
@@ -232,6 +243,7 @@ def saliency_completeness(
         target_class=target_class,
         positive_only=positive_only,
         percentiles=percentile,
+        parallel=parallel,
    )
 
     output_json = output_folder / (algo + ".json")
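
Reviewer note: the sketch below is not part of the patch; it only summarises the contract of the new `parallel` option as implemented above. The helper name `_instances` is hypothetical, introduced purely for illustration.

    import multiprocessing

    def _instances(parallel: int) -> int | None:
        # Mirrors the option's contract from this patch: -1 disables
        # multiprocessing (samples are processed serially in the current
        # process context), 0 uses as many worker processes as there are
        # CPU cores, and any value >= 1 requests exactly that many workers.
        if parallel < 0:
            return None
        return parallel or multiprocessing.cpu_count()

    # On an 8-core machine: _instances(-1) is None, _instances(0) == 8,
    # and _instances(4) == 4.

Because `run` dispatches samples through `multiprocessing.Pool`, the `functools.partial` worker (and the model it captures) must be made available to each worker process; this is also why both guards above reject more than one processing instance whenever a GPU device is selected.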