Skip to content
Snippets Groups Projects
Commit 8a15da7b authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine] Use mask for performance evaluation

parent 92d531a2
No related branches found
No related tags found
No related merge requests found
...@@ -24,7 +24,39 @@ logger = logging.getLogger(__name__) ...@@ -24,7 +24,39 @@ logger = logging.getLogger(__name__)
def _posneg(pred, gt, threshold): def _posneg(pred, gt, threshold):
"""Calculates true and false positives and negatives""" """Calculates true and false positives and negatives
Parameters
----------
pred : torch.Tensor
pixel-wise predictions
gt : torch.Tensor
ground-truth (annotations)
threshold : float
a particular threshold in which to calculate the performance
measures
Returns
-------
tp_tensor : torch.Tensor
boolean tensor with true positives, considering all observations
fp_tensor : torch.Tensor
boolean tensor with false positives, considering all observations
tn_tensor : torch.Tensor
boolean tensor with true negatives, considering all observations
fn_tensor : torch.Tensor
boolean tensor with false negatives, considering all observations
"""
gt = gt.byte() # byte tensor gt = gt.byte() # byte tensor
...@@ -39,18 +71,18 @@ def _posneg(pred, gt, threshold): ...@@ -39,18 +71,18 @@ def _posneg(pred, gt, threshold):
tp_tensor = gt * binary_pred tp_tensor = gt * binary_pred
# false positives # false positives
fp_tensor = torch.eq((binary_pred + tp_tensor), 1) fp_tensor = torch.eq((binary_pred + tp_tensor), 1).byte()
# true negatives # true negatives
tn_tensor = equals - tp_tensor tn_tensor = equals - tp_tensor
# false negatives # false negatives
fn_tensor = notequals - fp_tensor.type(torch.uint8) fn_tensor = notequals - fp_tensor
return tp_tensor, fp_tensor, tn_tensor, fn_tensor return tp_tensor, fp_tensor, tn_tensor, fn_tensor
def _sample_measures_for_threshold(pred, gt, threshold): def sample_measures_for_threshold(pred, gt, mask, threshold):
""" """
Calculates measures on one single sample, for a specific threshold Calculates measures on one single sample, for a specific threshold
...@@ -64,6 +96,9 @@ def _sample_measures_for_threshold(pred, gt, threshold): ...@@ -64,6 +96,9 @@ def _sample_measures_for_threshold(pred, gt, threshold):
gt : torch.Tensor gt : torch.Tensor
ground-truth (annotations) ground-truth (annotations)
mask : torch.Tensor
region mask (used only if available). May be set to ``None``.
threshold : float threshold : float
a particular threshold in which to calculate the performance a particular threshold in which to calculate the performance
measures measures
...@@ -88,15 +123,25 @@ def _sample_measures_for_threshold(pred, gt, threshold): ...@@ -88,15 +123,25 @@ def _sample_measures_for_threshold(pred, gt, threshold):
tp_tensor, fp_tensor, tn_tensor, fn_tensor = _posneg(pred, gt, threshold) tp_tensor, fp_tensor, tn_tensor, fn_tensor = _posneg(pred, gt, threshold)
# if a mask is provided, consider only TP/FP/TN/FN **within** the region of
# interest defined by the mask
if mask is not None:
antimask = torch.le(mask, 0.5)
tp_tensor[antimask] = 0
fp_tensor[antimask] = 0
tn_tensor[antimask] = 0
fn_tensor[antimask] = 0
# calc measures from scalars # calc measures from scalars
tp_count = torch.sum(tp_tensor).item() tp_count = torch.sum(tp_tensor).item()
fp_count = torch.sum(fp_tensor).item() fp_count = torch.sum(fp_tensor).item()
tn_count = torch.sum(tn_tensor).item() tn_count = torch.sum(tn_tensor).item()
fn_count = torch.sum(fn_tensor).item() fn_count = torch.sum(fn_tensor).item()
return base_measures(tp_count, fp_count, tn_count, fn_count) return base_measures(tp_count, fp_count, tn_count, fn_count)
def _sample_measures(pred, gt, steps): def _sample_measures(pred, gt, mask, steps):
""" """
Calculates measures on one single sample Calculates measures on one single sample
...@@ -110,6 +155,9 @@ def _sample_measures(pred, gt, steps): ...@@ -110,6 +155,9 @@ def _sample_measures(pred, gt, steps):
gt : torch.Tensor gt : torch.Tensor
ground-truth (annotations) ground-truth (annotations)
mask : torch.Tensor
region mask (used only if available). May be set to ``None``.
steps : int steps : int
number of steps to use for threshold analysis. The step size is number of steps to use for threshold analysis. The step size is
calculated from this by dividing ``1.0/steps`` calculated from this by dividing ``1.0/steps``
...@@ -134,7 +182,8 @@ def _sample_measures(pred, gt, steps): ...@@ -134,7 +182,8 @@ def _sample_measures(pred, gt, steps):
step_size = 1.0 / steps step_size = 1.0 / steps
data = [ data = [
(index, threshold) + _sample_measures_for_threshold(pred, gt, threshold) (index, threshold) + sample_measures_for_threshold(pred, gt, mask,
threshold)
for index, threshold in enumerate(numpy.arange(0.0, 1.0, step_size)) for index, threshold in enumerate(numpy.arange(0.0, 1.0, step_size))
] ]
...@@ -157,6 +206,7 @@ def _sample_analysis( ...@@ -157,6 +206,7 @@ def _sample_analysis(
img, img,
pred, pred,
gt, gt,
mask,
threshold, threshold,
tp_color=(0, 255, 0), # (128,128,128) Gray tp_color=(0, 255, 0), # (128,128,128) Gray
fp_color=(0, 0, 255), # (70, 240, 240) Cyan fp_color=(0, 0, 255), # (70, 240, 240) Cyan
...@@ -178,6 +228,9 @@ def _sample_analysis( ...@@ -178,6 +228,9 @@ def _sample_analysis(
gt : torch.Tensor gt : torch.Tensor
ground-truth (annotations) ground-truth (annotations)
mask : torch.Tensor
region mask (used only if available). May be set to ``None``.
threshold : float threshold : float
The threshold to be used while analyzing this image's probability map The threshold to be used while analyzing this image's probability map
...@@ -207,6 +260,15 @@ def _sample_analysis( ...@@ -207,6 +260,15 @@ def _sample_analysis(
tp_tensor, fp_tensor, tn_tensor, fn_tensor = _posneg(pred, gt, threshold) tp_tensor, fp_tensor, tn_tensor, fn_tensor = _posneg(pred, gt, threshold)
# if a mask is provided, consider only TP/FP/TN/FN **within** the region of
# interest defined by the mask
if mask is not None:
antimask = torch.le(mask, 0.5)
tp_tensor[antimask] = 0
fp_tensor[antimask] = 0
tn_tensor[antimask] = 0
fn_tensor[antimask] = 0
# change to PIL representation # change to PIL representation
tp_pil = VF.to_pil_image(tp_tensor.float()) tp_pil = VF.to_pil_image(tp_tensor.float())
tp_pil_colored = PIL.ImageOps.colorize(tp_pil, (0, 0, 0), tp_color) tp_pil_colored = PIL.ImageOps.colorize(tp_pil, (0, 0, 0), tp_color)
...@@ -295,6 +357,7 @@ def run( ...@@ -295,6 +357,7 @@ def run(
stem = sample[0] stem = sample[0]
image = sample[1] image = sample[1]
gt = sample[2] gt = sample[2]
mask = None if len(sample) <= 3 else sample[3]
pred_fullpath = os.path.join(use_predictions_folder, stem + ".hdf5") pred_fullpath = os.path.join(use_predictions_folder, stem + ".hdf5")
with h5py.File(pred_fullpath, "r") as f: with h5py.File(pred_fullpath, "r") as f:
pred = f["array"][:] pred = f["array"][:]
...@@ -303,7 +366,7 @@ def run( ...@@ -303,7 +366,7 @@ def run(
raise RuntimeError( raise RuntimeError(
f"{stem} entry already exists in data. Cannot overwrite." f"{stem} entry already exists in data. Cannot overwrite."
) )
data[stem] = _sample_measures(pred, gt, steps) data[stem] = _sample_measures(pred, gt, mask, steps)
if output_folder is not None: if output_folder is not None:
fullpath = os.path.join(output_folder, name, f"{stem}.csv") fullpath = os.path.join(output_folder, name, f"{stem}.csv")
...@@ -313,7 +376,7 @@ def run( ...@@ -313,7 +376,7 @@ def run(
if overlayed_folder is not None: if overlayed_folder is not None:
overlay_image = _sample_analysis( overlay_image = _sample_analysis(
image, pred, gt, threshold=threshold, overlay=True image, pred, gt, mask, threshold=threshold, overlay=True
) )
fullpath = os.path.join(overlayed_folder, name, f"{stem}.png") fullpath = os.path.join(overlayed_folder, name, f"{stem}.png")
tqdm.write(f"Saving {fullpath}...") tqdm.write(f"Saving {fullpath}...")
...@@ -432,11 +495,12 @@ def compare_annotators( ...@@ -432,11 +495,12 @@ def compare_annotators(
image = baseline_sample[1] image = baseline_sample[1]
gt = baseline_sample[2] gt = baseline_sample[2]
pred = other_sample[2] # works as a prediction pred = other_sample[2] # works as a prediction
mask = None if len(sample) <= 3 else sample[3]
if stem in data: if stem in data:
raise RuntimeError( raise RuntimeError(
f"{stem} entry already exists in data. " f"Cannot overwrite." f"{stem} entry already exists in data. " f"Cannot overwrite."
) )
data[stem] = _sample_measures(pred, gt, 2) data[stem] = _sample_measures(pred, gt, mask, 2)
if output_folder is not None: if output_folder is not None:
fullpath = os.path.join( fullpath = os.path.join(
...@@ -448,7 +512,7 @@ def compare_annotators( ...@@ -448,7 +512,7 @@ def compare_annotators(
if overlayed_folder is not None: if overlayed_folder is not None:
overlay_image = _sample_analysis( overlay_image = _sample_analysis(
image, pred, gt, threshold=0.5, overlay=True image, pred, gt, mask, threshold=0.5, overlay=True
) )
fullpath = os.path.join( fullpath = os.path.join(
overlayed_folder, "second-annotator", name, f"{stem}.png" overlayed_folder, "second-annotator", name, f"{stem}.png"
......
...@@ -5,9 +5,11 @@ import random ...@@ -5,9 +5,11 @@ import random
import unittest import unittest
import math import math
import torch
import nose.tools import nose.tools
from ..utils.measure import base_measures, auc from ..utils.measure import base_measures, auc
from ..engine.evaluator import sample_measures_for_threshold
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -103,3 +105,85 @@ def test_auc_raises_assertion_error(): ...@@ -103,3 +105,85 @@ def test_auc_raises_assertion_error():
# x is **not** the same size as y # x is **not** the same size as y
assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 1.0]), 1.0) assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 1.0]), 1.0)
def test_sample_measures_mask_checkerbox():
prediction = torch.ones((4, 4), dtype=float)
ground_truth = torch.ones((4, 4), dtype=float)
ground_truth[2:, :2] = 0.0
ground_truth[:2, 2:] = 0.0
mask = torch.zeros((4, 4), dtype=float)
mask[1:3, 1:3] = 1.0
threshold = 0.5
# with this configuration, this should be the correct count
tp = 2
fp = 2
tn = 0
fn = 0
nose.tools.eq_(
base_measures(tp, fp, tn, fn),
sample_measures_for_threshold(
prediction, ground_truth, mask, threshold
),
)
def test_sample_measures_mask_cross():
prediction = torch.ones((10, 10), dtype=float)
prediction[0,:] = 0.0
prediction[9,:] = 0.0
ground_truth = torch.ones((10, 10), dtype=float)
ground_truth[:5,] = 0.0 #lower part is not to be set
mask = torch.zeros((10, 10), dtype=float)
mask[(0,1,2,3,4,5,6,7,8,9),(0,1,2,3,4,5,6,7,8,9)] = 1.0
mask[(0,1,2,3,4,5,6,7,8,9),(9,8,7,6,5,4,3,2,1,0)] = 1.0
threshold = 0.5
# with this configuration, this should be the correct count
tp = 8
fp = 8
tn = 2
fn = 2
nose.tools.eq_(
base_measures(tp, fp, tn, fn),
sample_measures_for_threshold(
prediction, ground_truth, mask, threshold
),
)
def test_sample_measures_mask_border():
prediction = torch.zeros((10, 10), dtype=float)
prediction[:,4] = 1.0
prediction[:,5] = 1.0
prediction[0,4] = 0.0
prediction[8,4] = 0.0
prediction[1,6] = 1.0
ground_truth = torch.zeros((10, 10), dtype=float)
ground_truth[:,4] = 1.0
ground_truth[:,5] = 1.0
mask = torch.ones((10, 10), dtype=float)
mask[:,0] = 0.0
mask[0,:] = 0.0
mask[:,9] = 0.0
mask[9,:] = 0.0
threshold = 0.5
# with this configuration, this should be the correct count
tp = 15
fp = 1
tn = 47
fn = 1
nose.tools.eq_(
base_measures(tp, fp, tn, fn),
sample_measures_for_threshold(
prediction, ground_truth, mask, threshold
),
)
...@@ -12,7 +12,7 @@ can be downloaded. We include the reference of the data split protocols used ...@@ -12,7 +12,7 @@ can be downloaded. We include the reference of the data split protocols used
to generate iterators for training and testing. to generate iterators for training and testing.
.. list-table:: .. list-table:: Supported Datasets (``*`` provided within this package)
* - Dataset * - Dataset
- Reference - Reference
...@@ -40,7 +40,7 @@ to generate iterators for training and testing. ...@@ -40,7 +40,7 @@ to generate iterators for training and testing.
- [STARE-2000]_ - [STARE-2000]_
- 605 x 700 - 605 x 700
- 20 - 20
- - *
- x - x
- -
- -
...@@ -51,7 +51,7 @@ to generate iterators for training and testing. ...@@ -51,7 +51,7 @@ to generate iterators for training and testing.
- [CHASEDB1-2012]_ - [CHASEDB1-2012]_
- 960 x 999 - 960 x 999
- 28 - 28
- - *
- x - x
- -
- -
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment