Skip to content
Snippets Groups Projects
Commit 09466bb7 authored by André Anjos's avatar André Anjos :speech_balloon: Committed by Daniel CARRON
Browse files

[engine.saliency.completeness] Implement multiprocessing

parent fee2f98f
No related branches found
No related tags found
1 merge request!12Adds grad-cam support on classifiers
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
import functools
import logging import logging
import multiprocessing
import typing import typing
import lightning.pytorch import lightning.pytorch
...@@ -16,6 +18,7 @@ from pytorch_grad_cam.metrics.road import ( ...@@ -16,6 +18,7 @@ from pytorch_grad_cam.metrics.road import (
) )
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from ...data.typing import Sample
from ...models.typing import SaliencyMapAlgorithm from ...models.typing import SaliencyMapAlgorithm
from ..device import DeviceManager from ..device import DeviceManager
...@@ -110,6 +113,66 @@ def _calculate_road_scores( ...@@ -110,6 +113,66 @@ def _calculate_road_scores(
) )
def _process_sample(
    sample: Sample,
    model: lightning.pytorch.LightningModule,
    device: torch.device,
    saliency_map_callable: typing.Callable,
    target_class: typing.Literal["highest", "all"],
    positive_only: bool,
    percentiles: typing.Sequence[int],
) -> list:
    """Helper function to :py:func:`run` to be used in multiprocessing
    contexts.

    Processes a single sample: extracts its name and label, moves its image
    to ``device`` and delegates the score computation to
    :py:func:`_calculate_road_scores` for the selected model output.

    Parameters
    ----------
    sample
        The sample to process.  ``sample[0]`` is the input moved to
        ``device``; ``sample[1]`` is indexed with the keys ``"name"`` and
        ``"label"``.
    model
        The model being evaluated.  Its ``num_classes`` attribute decides
        between the binary (``== 1``) and the multi-class (``> 1``) routes.
    device
        The device to which the sample image is transferred.
    saliency_map_callable
        Object forwarded to :py:func:`_calculate_road_scores`.  When
        ``target_class`` is not ``"all"`` on a multi-class model, its
        ``activations_and_grads`` attribute is also called on the image to
        obtain the model outputs.
    target_class
        For multi-class models: ``"all"`` iterates over the model outputs,
        any other value selects the output with the highest value.
    positive_only
        If set, and the model has a single output, samples whose label is
        zero are not evaluated.
    percentiles
        Forwarded unchanged to :py:func:`_calculate_road_scores`.

    Returns
    -------
        ``[name, label]`` when the sample is skipped, otherwise
        ``[name, label, output_number, *results]``, where ``results`` is
        whatever :py:func:`_calculate_road_scores` returned for that output.
    """
    name: str = sample[1]["name"][0]
    label: int = int(sample[1]["label"].item())
    # non-blocking transfer is only requested when CUDA is available
    image = sample[0].to(device=device, non_blocking=torch.cuda.is_available())
    # in binary classification systems, negative labels may be skipped
    if positive_only and (model.num_classes == 1) and (label == 0):
        return [name, label]
    # chooses target outputs to generate saliency maps for
    if model.num_classes > 1:  # type: ignore
        if target_class == "all":
            # test all outputs
            for output_num in range(model.num_classes):  # type: ignore
                results = _calculate_road_scores(
                    model,
                    image,
                    output_num,
                    saliency_map_callable,
                    percentiles,
                )
                # NOTE(review): a single list is returned from this branch,
                # so only one output's scores ever reach the caller even
                # though the intent is to "test all outputs".  As laid out
                # here the function exits during the first iteration; the
                # exact nesting of this ``return`` should be confirmed
                # against the repository, and the caller adapted if every
                # output is meant to be reported.
                return [name, label, output_num, *results]
        else:
            # we will figure out the output with the highest value and
            # evaluate the saliency mapping technique over it.
            outputs = saliency_map_callable.activations_and_grads(image)  # type: ignore
            output_nums = np.argmax(outputs.cpu().data.numpy(), axis=-1)
            # exactly one prediction is expected here (presumably because
            # samples are processed one at a time -- TODO confirm)
            assert len(output_nums) == 1
            results = _calculate_road_scores(
                model,
                image,
                output_nums[0],
                saliency_map_callable,
                percentiles,
            )
            return [name, label, output_nums[0], *results]
    # default route for binary classification
    results = _calculate_road_scores(
        model,
        image,
        0,
        saliency_map_callable,
        percentiles,
    )
    return [name, label, 0, *results]
def run( def run(
model: lightning.pytorch.LightningModule, model: lightning.pytorch.LightningModule,
datamodule: lightning.pytorch.LightningDataModule, datamodule: lightning.pytorch.LightningDataModule,
...@@ -118,6 +181,7 @@ def run( ...@@ -118,6 +181,7 @@ def run(
target_class: typing.Literal["highest", "all"], target_class: typing.Literal["highest", "all"],
positive_only: bool, positive_only: bool,
percentiles: typing.Sequence[int], percentiles: typing.Sequence[int],
parallel: int,
) -> dict[str, list[typing.Any]]: ) -> dict[str, list[typing.Any]]:
"""Evaluates ROAD scores for all samples in a datamodule. """Evaluates ROAD scores for all samples in a datamodule.
...@@ -162,6 +226,11 @@ def run( ...@@ -162,6 +226,11 @@ def run(
A sequence of percentiles (percent x100) integer values indicating the A sequence of percentiles (percent x100) integer values indicating the
proportion of pixels to perturb in the original image to calculate both proportion of pixels to perturb in the original image to calculate both
MoRF and LeRF scores. MoRF and LeRF scores.
parallel
Use multiprocessing for data processing: if set to -1, disables
multiprocessing. Set to 0 to enable as many data processing instances
as there are processing cores available in the system. Set to >= 1 to
enable that many multiprocessing instances for data processing.
Returns Returns
...@@ -201,6 +270,17 @@ def run( ...@@ -201,6 +270,17 @@ def run(
raise TypeError(f"Model of type `{type(model)}` is not yet supported.") raise TypeError(f"Model of type `{type(model)}` is not yet supported.")
use_cuda = device_manager.device_type == "cuda" use_cuda = device_manager.device_type == "cuda"
if device_manager.device_type in ("cuda", "mps") and (
parallel == 0 or parallel > 1
):
raise RuntimeError(
f"The number of multiprocessing instances is set to {parallel} and "
f"you asked to use a GPU (device = `{device_manager.device_type}`"
f"). The currently implementation can only handle a single GPU. "
f"Either disable GPU utilisation or set the number of "
f"multiprocessing instances to one, or disable multiprocessing "
"entirely (ie. set it to -1)."
)
# prepares model for evaluation, cast to target device # prepares model for evaluation, cast to target device
device = device_manager.torch_device() device = device_manager.torch_device()
...@@ -216,59 +296,37 @@ def run( ...@@ -216,59 +296,37 @@ def run(
retval: dict[str, list[typing.Any]] = {} retval: dict[str, list[typing.Any]] = {}
# our worker function
_process = functools.partial(
_process_sample,
model=model,
device=device,
saliency_map_callable=saliency_map_callable,
target_class=target_class,
positive_only=positive_only,
percentiles=percentiles,
)
for k, v in datamodule.predict_dataloader().items(): for k, v in datamodule.predict_dataloader().items():
logger.info(f"Computing ROAD scores for dataset `{k}`...")
retval[k] = [] retval[k] = []
for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None): if parallel < 0:
name = sample[1]["name"][0] logger.info(
label = int(sample[1]["label"].item()) f"Computing ROAD scores for dataset `{k}` in the current "
image = sample[0].to( f"process context..."
device=device, non_blocking=torch.cuda.is_available()
) )
for sample in tqdm.tqdm(
# in binary classification systems, negative labels may be skipped v, desc="samples", leave=False, disable=None
if positive_only and (model.num_classes == 1) and (label == 0): ):
retval[k].append([name, label]) retval[k].append(_process(sample))
continue
else:
# chooses target outputs to generate saliency maps for instances = parallel or multiprocessing.cpu_count()
if model.num_classes > 1: logger.info(
if target_class == "all": f"Computing ROAD scores for dataset `{k}` using {instances} "
# test all outputs f"processes..."
for output_num in range(model.num_classes): )
results = _calculate_road_scores( with multiprocessing.Pool(instances) as p:
model, retval[k] = list(tqdm.tqdm(p.imap(_process, v), total=len(v)))
image,
output_num,
saliency_map_callable,
percentiles,
)
retval[k].append([name, label, output_num, *results])
else:
# we will figure out the output with the highest value and
# evaluate the saliency mapping technique over it.
outputs = saliency_map_callable.activations_and_grads(image)
output_nums = np.argmax(outputs.cpu().data.numpy(), axis=-1)
assert len(output_nums) == 1
results = _calculate_road_scores(
model,
image,
output_nums[0],
saliency_map_callable,
percentiles,
)
retval[k].append([name, label, output_nums[0], *results])
else:
results = _calculate_road_scores(
model,
image,
0,
saliency_map_callable,
percentiles,
)
retval[k].append([name, label, 0, *results])
return retval return retval
...@@ -93,10 +93,12 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -93,10 +93,12 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.option( @click.option(
"--parallel", "--parallel",
"-P", "-P",
help="""Use multiprocessing for data loading: if set to -1 (default), help="""Use multiprocessing for data loading processing: if set to -1
disables multiprocessing data loading. Set to 0 to enable as many data (default), disables multiprocessing. Set to 0 to enable as many data
loading instances as processing cores as available in the system. Set to processing instances as processing cores available in the system. Set to
>= 1 to enable that many multiprocessing instances for data loading.""", >= 1 to enable that many multiprocessing instances. Note that if you
activate this option, then you must use --device=cpu, as using a GPU
concurrently is not supported.""",
type=click.IntRange(min=-1), type=click.IntRange(min=-1),
show_default=True, show_default=True,
required=True, required=True,
...@@ -203,6 +205,15 @@ def saliency_completeness( ...@@ -203,6 +205,15 @@ def saliency_completeness(
logger.info(f"Output folder: {output_folder}") logger.info(f"Output folder: {output_folder}")
output_folder.mkdir(parents=True, exist_ok=True) output_folder.mkdir(parents=True, exist_ok=True)
if device in ("cuda", "mps") and (parallel == 0 or parallel > 1):
raise RuntimeError(
f"The number of multiprocessing instances is set to {parallel} and "
f"you asked to use a GPU (device = `{device}`). The currently "
f"implementation can only handle a single GPU. Either disable GPU "
f"utilisation or set the number of multiprocessing instances to "
f"one, or disable multiprocessing entirely (ie. set it to -1)."
)
device_manager = DeviceManager(device) device_manager = DeviceManager(device)
# batch_size must be == 1 for now (underlying code is NOT prepared to # batch_size must be == 1 for now (underlying code is NOT prepared to
...@@ -232,6 +243,7 @@ def saliency_completeness( ...@@ -232,6 +243,7 @@ def saliency_completeness(
target_class=target_class, target_class=target_class,
positive_only=positive_only, positive_only=positive_only,
percentiles=percentile, percentiles=percentile,
parallel=parallel,
) )
output_json = output_folder / (algo + ".json") output_json = output_folder / (algo + ".json")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment