Commit 982404c9 authored by André Anjos, committed by Daniel CARRON


[saliencymap_evaluator] Reimplements interpretability evaluator to simplify code, be less dependent on opencv, improve documentation, and improve type hinting
parent 2e149443
1 merge request: !12 Adds grad-cam support on classifiers
@@ -65,3 +65,8 @@
   **Rethinking computer-aided tuberculosis diagnosis.**,
   In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
   Recognition, pages 2646–2655.

.. [SCORECAM-2020] *H. Wang et al.*, **Score-CAM: Score-Weighted Visual
   Explanations for Convolutional Neural Networks**, 2020 IEEE/CVF Conference on
   Computer Vision and Pattern Recognition Workshops (CVPRW), Seattle, WA, USA,
   2020, pp. 111-119. doi: https://doi.org/10.1109/CVPRW50498.2020.00020
@@ -3,17 +3,108 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import pathlib
import typing

import cv2
import lightning.pytorch
import numpy
import numpy.typing
import torch
import torchvision.ops

from tqdm import tqdm

logger = logging.getLogger(__name__)
def _ordered_connected_components(
    saliency_map: typing.Sequence[typing.Sequence[float]]
    | numpy.typing.NDArray[numpy.double],
) -> list[numpy.typing.NDArray[numpy.bool_]]:
    """Calculates the connected components available on a saliency map and
    returns those as individual masks, ordered by decreasing size.

    This implementation is based on [SCORECAM-2020]_:

    1. Thresholding: pixel values above 20% of the maximum value are kept in
       the original saliency map.  Everything else is set to zero.
    2. The thresholded saliency map is transformed into a boolean array (ones
       are attributed to all elements above the threshold).
    3. We run connected-component labelling (:py:func:`cv2.connectedComponents`)
       to find all connected components and label them with distinct integers.
    4. We histogram the labels and return one binary mask for each label,
       sorted by decreasing size.


    Parameters
    ----------
    saliency_map
        Input saliency map from which connected components will be
        calculated.


    Returns
    -------
    A list of boolean masks, one for each connected component, ordered by
    decreasing size.  This list may be empty if the input ``saliency_map``
    is all zeroes.
    """

    # thresholds like [SCORECAM-2020]_
    thresholded_mask = (saliency_map >= (0.2 * numpy.max(saliency_map))).astype(
        numpy.uint8
    )

    # avoids an all-zeroes mask being processed
    if not numpy.any(thresholded_mask):
        return []

    # opencv implementation:
    n, labelled = cv2.connectedComponents(thresholded_mask, connectivity=8)
    retval = [labelled == k for k in range(1, n)]

    # scikit-image alternative:
    # import skimage.measure
    # labelled, n = skimage.measure.label(thresholded_mask, return_num=True)
    # retval = [labelled == k for k in range(1, n + 1)]

    return sorted(retval, key=lambda x: x.sum(), reverse=True)
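To make the thresholding and labelling steps concrete, here is a minimal, self-contained sketch (illustrative only, not part of this module) that reproduces them on a toy saliency map with two blobs:

import cv2
import numpy

# toy saliency map with two bright blobs of different sizes
saliency_map = numpy.zeros((8, 8), dtype=numpy.double)
saliency_map[1:4, 1:4] = 1.0  # 3x3 blob
saliency_map[6:8, 6:8] = 0.5  # 2x2 blob, still above 20% of the maximum

thresholded = (saliency_map >= 0.2 * saliency_map.max()).astype(numpy.uint8)
n, labelled = cv2.connectedComponents(thresholded, connectivity=8)
components = sorted(
    (labelled == k for k in range(1, n)), key=lambda m: m.sum(), reverse=True
)
print([int(m.sum()) for m in components])  # [9, 4] -- largest component first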
def _extract_bounding_box(
    mask: numpy.typing.NDArray[numpy.bool_],
) -> tuple[int, int, int, int]:
    """Defines a bounding box surrounding a connected component mask.


    Parameters
    ----------
    mask
        The connected component mask from which to extract the bounding box.


    Returns
    -------
    A tuple of 4 integers representing the bounding box with the following
    components:

    * top-left horizontal coordinate (``x``) (pixels)
    * top-left vertical coordinate (``y``) (pixels)
    * width (pixels)
    * height (pixels)
    """

    # opencv alternative:
    # x, y, w, h = cv2.boundingRect(mask.astype(numpy.uint8))

    x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[
        0
    ]
    return (int(x), int(y), int(x2 - x + 1), int(y2 - y + 1))
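A quick standalone check of the corner-to-width/height conversion used above (not part of the diff; values are illustrative):

import torch
import torchvision.ops

mask = torch.zeros((5, 7), dtype=torch.bool)
mask[1:4, 2:6] = True  # 3 rows tall, 4 columns wide

# masks_to_boxes() returns inclusive (x1, y1, x2, y2) corners, hence the +1
x1, y1, x2, y2 = torchvision.ops.masks_to_boxes(mask[None, :])[0]
print(int(x1), int(y1), int(x2 - x1 + 1), int(y2 - y1 + 1))  # 2 1 4 3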
def _compute_max_iou_and_ioda(
    detected_box: tuple[int, int, int, int],
    gt_bboxes: typing.Sequence[tuple[int, int, int, int]],
) -> tuple[float, float]:
    """Calculates how much of the detected area lies within the ground-truth
    boxes.

    If there are multiple gt boxes, the detected area will be calculated
@@ -27,9 +118,8 @@ def _compute_max_iou_and_ioda(detected_box, gt_boxes):
    max_intersection = 0
    max_gt_area = 0

    for bbox in gt_bboxes:
        xmin, ymin, width, height = bbox

        gt_area = width * height
@@ -49,23 +139,26 @@ def _compute_max_iou_and_ioda(detected_box, gt_boxes):
    if max_gt_area == 0 and max_intersection == 0:
        # This case means no intersection was found, even though there are gt boxes
        iou, ioda = 0.0, 0.0
    else:
        iou = max_intersection / (
            detected_area + max_gt_area - max_intersection
        )
        ioda = max_intersection / detected_area

    return float(iou), float(ioda)


def _compute_simultaneous_iou_and_ioda(
    detected_box: tuple[int, int, int, int],
    gt_bboxes: typing.Sequence[tuple[int, int, int, int]],
) -> tuple[float, float]:
    """Calculates how much of the detected area lies within the ground-truth
    boxes.

    This means that if there are multiple gt boxes, the detected area
    will be compared to them simultaneously (and not to each gt box
    separately).
    """

    x_L, y_L, w_L, h_L = detected_box
    detected_area = w_L * h_L
@@ -74,9 +167,8 @@ def _compute_simultaneous_iou_and_ioda(detected_box, gt_boxes):
    intersection = 0
    total_gt_area = 0

    for bbox in gt_bboxes:
        xmin, ymin, width, height = bbox

        gt_area = width * height
        total_gt_area += gt_area
@@ -93,203 +185,199 @@ def _compute_simultaneous_iou_and_ioda(detected_box, gt_boxes):
    iou = intersection / (detected_area + total_gt_area - intersection)
    ioda = intersection / detected_area

    return float(iou), float(ioda)
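Since the loop bodies of the IoU/IoDA helpers are partially elided in this diff, here is a standalone sketch of the underlying arithmetic for a single pair of (x, y, width, height) boxes; the module functions additionally iterate over several ground-truth boxes:

def iou_and_ioda(detected, gt):
    """Returns (IoU, IoDA) for two (x, y, width, height) boxes."""
    dx, dy, dw, dh = detected
    gx, gy, gw, gh = gt
    # overlap extents along each axis (0 if the boxes do not intersect)
    iw = max(0, min(dx + dw, gx + gw) - max(dx, gx))
    ih = max(0, min(dy + dh, gy + gh) - max(dy, gy))
    intersection = iw * ih
    detected_area, gt_area = dw * dh, gw * gh
    iou = intersection / (detected_area + gt_area - intersection)
    ioda = intersection / detected_area
    return iou, ioda

print(iou_and_ioda((0, 0, 4, 4), (2, 2, 4, 4)))  # (0.14285714285714285, 0.25)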
def _compute_avg_saliency_focus(
    saliency_map: numpy.typing.NDArray[numpy.double],
    gt_mask: numpy.typing.NDArray[numpy.bool_],
) -> float:
    """Integrates the saliency map over the ground-truth boxes and normalizes
    by the total bounding-box area.

    This function will integrate (sum) the value of the saliency map over the
    ground-truth bounding boxes and normalize it by the total area covered by
    all ground-truth bounding boxes.


    Parameters
    ----------
    saliency_map
        A real-valued saliency map that conveys regions used for
        classification in the original sample.
    gt_mask
        Ground-truth mask containing the bounding boxes of the ground truth
        drawn as ``True`` values.


    Returns
    -------
    A single floating-point number representing the average saliency focus.
    """

    area = gt_mask.sum()
    if area == 0:
        return 0.0

    return float(numpy.sum(saliency_map * gt_mask) / area)


def _compute_proportional_energy(
    saliency_map: numpy.typing.NDArray[numpy.double],
    gt_mask: numpy.typing.NDArray[numpy.bool_],
) -> float:
    """Calculates how much activation lies within the ground-truth boxes
    compared to the total sum of the activations (integral).


    Parameters
    ----------
    saliency_map
        A real-valued saliency map that conveys regions used for
        classification in the original sample.
    gt_mask
        Ground-truth mask containing the bounding boxes of the ground truth
        drawn as ``True`` values.


    Returns
    -------
    A single floating-point number representing the proportional energy.
    """

    denominator = numpy.sum(saliency_map)

    if denominator == 0.0:
        return 0.0

    return float(numpy.sum(saliency_map * gt_mask) / denominator)
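A short, self-contained sketch of the two measures defined above on a toy map with a single ground-truth region (values are illustrative only):

import numpy

saliency_map = numpy.zeros((4, 4), dtype=numpy.double)
saliency_map[0:2, 0:2] = 1.0  # all activation in the top-left corner
gt_mask = numpy.zeros((4, 4), dtype=numpy.bool_)
gt_mask[0:2, 0:4] = True      # ground truth covers the top half

proportional_energy = numpy.sum(saliency_map * gt_mask) / numpy.sum(saliency_map)
avg_saliency_focus = numpy.sum(saliency_map * gt_mask) / gt_mask.sum()
print(proportional_energy, avg_saliency_focus)  # 1.0 0.5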
def _process_sample(
    gt_bboxes: typing.Sequence[tuple[int, int, int, int]],
    saliency_map: numpy.typing.NDArray[numpy.double],
) -> tuple[float, float, float, float, float, float, float, float]:
    """Calculates the interpretability metrics for a single sample.


    Parameters
    ----------
    gt_bboxes
        A list of ground-truth bounding boxes following the format:

        * xmin: horizontal position of bounding box upper-left corner, in
          pixels
        * ymin: vertical position of bounding box upper-left corner, in
          pixels
        * width: width of the bounding box, in pixels
        * height: height of the bounding box, in pixels
    saliency_map
        A real-valued saliency map that conveys regions used for
        classification in the original sample.


    Returns
    -------
    A tuple containing the IoU, IoDA, proportional energy and average
    saliency focus for the sample, followed by the four components of the
    largest detected bounding box (``x``, ``y``, ``width``, ``height``).
    """

    masks = _ordered_connected_components(saliency_map)

    detected_box = (0, 0, 0, 0)
    if masks:
        # we get only the largest bounding box as of now
        detected_box = _extract_bounding_box(masks[0])

    # Calculate localization metrics
    iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)

    # The binary_mask will be ON/True where the gt boxes are located
    binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_)
    for bbox in gt_bboxes:
        xmin, ymin, width, height = bbox
        binary_mask[ymin : ymin + height, xmin : xmin + width] = True

    return (
        iou,
        ioda,
        _compute_proportional_energy(saliency_map, binary_mask),
        _compute_avg_saliency_focus(saliency_map, binary_mask),
        *detected_box,
    )
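For reference, the ground-truth mask construction used by ``_process_sample`` boils down to the following standalone snippet (the box values are hypothetical):

import numpy

saliency_map = numpy.zeros((10, 10), dtype=numpy.double)
gt_bboxes = [(2, 3, 4, 5)]  # hypothetical (x, y, width, height) box

binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_)
for xmin, ymin, width, height in gt_bboxes:
    binary_mask[ymin : ymin + height, xmin : xmin + width] = True

print(int(binary_mask.sum()))  # 20 == 4 * 5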
def run(
    input_folder: pathlib.Path,
    datamodule: lightning.pytorch.LightningDataModule,
) -> dict[str, list[typing.Any]]:
    """Evaluates saliency-map interpretability metrics for all samples in a
    datamodule, using pre-computed saliency maps stored on disk.


    Parameters
    ----------
    input_folder
        Directory in which the saliency maps are stored for a specific
        visualization type.
    datamodule
        The lightning datamodule to iterate on.


    Returns
    -------
    A dictionary where keys are dataset names in the provided datamodule,
    and values are lists containing sample information alongside the
    calculated metrics:

    * Sample name
    * Sample target class
    * IoU
    * IoDA
    * Proportional energy
    * Average saliency focus
    * Largest detected bounding box
    """

    retval: dict[str, list[typing.Any]] = {}

    for dataset_name, dataset_loader in datamodule.predict_dataloader().items():
        logger.info(
            f"Estimating interpretability metrics for dataset `{dataset_name}`..."
        )
        retval[dataset_name] = []

        # TODO: This loads the images from the dataset, but they are not useful
        # at this point...
        for sample in tqdm(
            dataset_loader, desc="batches", leave=False, disable=None
        ):
            name = str(sample[1]["name"][0])
            label = int(sample[1]["label"].item())

            # TODO: This is very specific to the TBX11k system for labelling
            # regions of interest.  We need to abstract from this to support
            # more datasets and other ways to annotate.
            bboxes = sample[1].get("radsign_bboxes", [])
            bboxes = [k[1:] for k in bboxes]  # remove bbox label...

            if label == 0:
                # TB-negative samples carry no gt bboxes; we still add an
                # entry for dataset completeness
                retval[dataset_name].append([name, label])
                continue

            if not bboxes:
                logger.warning(
                    f"Sample `{name}` does not contain bounding-box information. "
                    f"No localization metrics can be calculated in this case. "
                    f"Skipping..."
                )
                # we add the entry for dataset completeness
                retval[dataset_name].append([name, label])
                continue

            # we fully process this entry
            retval[dataset_name].append(
                [
                    name,
                    label,
                    *_process_sample(
                        bboxes,
                        numpy.load(
                            input_folder
                            / pathlib.Path(name).with_suffix(".npy")
                        ),
                    ),
                ]
            )

    return retval
@@ -2,58 +2,16 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib

import click

from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup

from .click import ConfigCommand

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


@click.command(
@@ -61,84 +19,119 @@ def prepare_csv_writers(input_folder, dataset_split):
    cls=ConfigCommand,
    epilog="""Examples:

1. Evaluate the generated saliency maps for their localization performance:

   .. code:: sh

      ptbench evaluate-saliencymaps -vv tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-json=parent_folder/gradcam/tbx11k-v1-interp.json

""",
)
@click.option(
    "--datamodule",
    "-d",
    help="""A lightning data module that will be asked for prediction data
    loaders.  Typically, this includes all configured splits in a datamodule,
    however this is not a requirement.  A datamodule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--input-folder",
    "-i",
    help="""Path where to load saliency maps from.""",
    required=True,
    type=click.Path(
        exists=True,
        file_okay=False,
        dir_okay=True,
        path_type=pathlib.Path,
    ),
    default="saliency-maps",
    cls=ResourceOption,
)
@click.option(
    "--output-json",
    "-o",
    help="""Path where to store the output JSON file containing all
    measures.""",
    required=True,
    type=click.Path(
        file_okay=True,
        dir_okay=False,
        path_type=pathlib.Path,
    ),
    default="saliencymap-interpretability.json",
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def evaluate_saliencymaps(
    datamodule,
    input_folder,
    output_json,
    **_,
) -> None:
    """Evaluates saliency map agreement with annotations (human
    interpretability).

    The evaluation happens by comparing saliency maps with ground-truth
    annotations provided by any other means (typically following a manual
    annotation procedure).

    .. note::

       For obvious reasons, this evaluation is limited to databases that
       contain built-in annotations which corroborate the classification.

    As a result of the evaluation, this application creates a single JSON file
    that resembles the original datamodule, with added information containing
    the following measures, for each sample:

    * IoU: the intersection of the (thresholded) saliency map with the
      annotation that most overlaps it, over the union of both areas.
    * IoDA: the intersection of the (thresholded) saliency map with the
      annotation that most overlaps it, over the area of the (thresholded)
      saliency map.
    * Proportional Energy: a measure that compares the (unthresholded)
      saliency map with the annotations, originating from "Score-CAM:
      Score-Weighted Visual Explanations for Convolutional Neural Networks"
      by Wang et al. (2020), https://arxiv.org/abs/1910.01279.  It estimates
      how much activation lies within the ground-truth boxes compared to the
      total sum of the activations.
    * Average Saliency Focus: estimates how much of the ground-truth
      bounding-box area is covered by the activations.

    .. important::

       The thresholding algorithm used to evaluate IoU and IoDA measures is
       based on the process described in the original CAM paper ("Learning
       Deep Features for Discriminative Localization" by Zhou et al. (2015),
       https://arxiv.org/abs/1512.04150).  It keeps all points of the saliency
       map that are above 20% of its maximum value.

       It then calculates a **single** bounding box for the largest connected
       component.  This bounding box represents detected elements on the
       original sample that corroborate the classification outcome.

       IoU and IoDA are only evaluated for a single ground-truth bounding box
       per sample (the one with the highest overlap).  Any other bounding box
       marked on the sample is ignored in the present implementation.
    """
    import json

    from ..engine.saliencymap_evaluator import run
    from .utils import save_sh_command

    save_sh_command(input_folder / "command.sh")

    datamodule.model_transforms = []

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    results = run(input_folder, datamodule)

    with output_json.open("w") as f:
        logger.info(f"Saving output file to `{str(output_json)}`...")
        json.dump(results, f, indent=2)
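Assuming the command and defaults shown in this diff, the output JSON can be inspected as sketched below; folder names are placeholders and the row layout follows the ``run()`` docstring:

import json

# after running, e.g.:
#   ptbench evaluate-saliencymaps -vv tbx11k-v1-healthy-vs-atb \
#     --input-folder=parent_folder/gradcam/ \
#     --output-json=parent_folder/gradcam/tbx11k-v1-interp.json
# the output can be inspected like this (paths are placeholders):
with open("parent_folder/gradcam/tbx11k-v1-interp.json") as f:
    results = json.load(f)

for dataset_name, rows in results.items():
    # each fully-processed row is:
    # [name, label, iou, ioda, prop_energy, avg_sal_focus, x, y, width, height]
    complete = [r for r in rows if len(r) > 2]
    print(dataset_name, len(rows), "samples,", len(complete), "with metrics")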