Skip to content
Snippets Groups Projects
Commit 7c3b7704 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.significance] Fix mask use

parent f69d8dd7
No related branches found
No related tags found
No related merge requests found
Pipeline #42457 failed
......@@ -150,7 +150,7 @@ def _winperf_measures(pred, gt, mask, threshold, size, stride):
ground-truth (annotations)
mask : torch.Tensor
mask for the region of interest, optional (only used if specified)
mask for the region of interest
threshold : float
threshold to use for evaluating individual sliding window performances
......@@ -191,9 +191,15 @@ def _winperf_measures(pred, gt, mask, threshold, size, stride):
if rem != 0:
padding += (0, (stride[0] - rem))
pred_padded = torch.nn.functional.pad(pred, padding)
gt_padded = torch.nn.functional.pad(gt.squeeze(0), padding)
mask_padded = torch.nn.functional.pad(mask.squeeze(0), padding)
pred_padded = torch.nn.functional.pad(
pred, padding, mode="constant", value=0.0
)
gt_padded = torch.nn.functional.pad(
gt.squeeze(0), padding, mode="constant", value=0.0
)
mask_padded = torch.nn.functional.pad(
mask.squeeze(0), padding, mode="constant", value=1.0
)
# this will create as many views as required
pred_windows = pred_padded.unfold(0, size[0], stride[0]).unfold(
......@@ -364,7 +370,8 @@ def _winperf_for_sample(
sample = dataset[k]
with h5py.File(os.path.join(basedir, sample[0] + ".hdf5"), "r") as f:
pred = torch.from_numpy(f["array"][:])
winperf = _winperf_measures(pred, sample[2], threshold, size, stride)
mask = None if len(sample) < 4 else sample[3]
winperf = _winperf_measures(pred, sample[2], mask, threshold, size, stride)
n, avg, std = _performance_summary(
sample[1].shape[1:], winperf, size, stride, figure
)
......
......@@ -17,11 +17,13 @@ from ..engine.significance import (
from ..utils.measure import base_measures
def _check_window_measures(pred, gt, threshold, size, stride, expected):
def _check_window_measures(pred, gt, mask, threshold, size, stride, expected):
pred = torch.tensor(pred)
gt = torch.tensor(gt)
actual = _winperf_measures(pred, gt, threshold, size, stride)
if mask is None:
mask = torch.ones_like(gt)
actual = _winperf_measures(pred, gt, mask, threshold, size, stride)
# transforms tp,tn,fp,fn through base_measures()
expected_shape = numpy.array(expected).shape[:2]
......@@ -37,6 +39,7 @@ def test_winperf_measures_alltrue():
pred = numpy.ones((4, 4), dtype=float)
gt = numpy.ones((4, 4), dtype=bool)
mask = None
threshold = 0.5
size = (2, 2)
stride = (1, 1)
......@@ -47,13 +50,14 @@ def test_winperf_measures_alltrue():
[(4, 0, 0, 0), (4, 0, 0, 0), (4, 0, 0, 0)],
[(4, 0, 0, 0), (4, 0, 0, 0), (4, 0, 0, 0)],
]
_check_window_measures(pred, gt, threshold, size, stride, expected)
_check_window_measures(pred, gt, mask, threshold, size, stride, expected)
def test_winperf_measures_alltrue_with_padding():
pred = numpy.ones((3, 3), dtype=float)
gt = numpy.ones((3, 3), dtype=bool)
mask = None
threshold = 0.5
size = (2, 2)
stride = (2, 2)
......@@ -63,7 +67,7 @@ def test_winperf_measures_alltrue_with_padding():
[(4, 0, 0, 0), (2, 0, 2, 0)],
[(2, 0, 2, 0), (1, 0, 3, 0)],
]
_check_window_measures(pred, gt, threshold, size, stride, expected)
_check_window_measures(pred, gt, mask, threshold, size, stride, expected)
def test_winperf_measures_dot_with_padding():
......@@ -71,6 +75,7 @@ def test_winperf_measures_dot_with_padding():
pred = numpy.ones((3, 3), dtype=float)
gt = numpy.zeros((3, 3), dtype=bool)
gt[1, 1] = 1.0 # white dot pattern
mask = None
threshold = 0.5
size = (2, 2)
stride = (2, 2)
......@@ -80,7 +85,7 @@ def test_winperf_measures_dot_with_padding():
[(1, 3, 0, 0), (0, 2, 2, 0)],
[(0, 2, 2, 0), (0, 1, 3, 0)],
]
_check_window_measures(pred, gt, threshold, size, stride, expected)
_check_window_measures(pred, gt, mask, threshold, size, stride, expected)
def test_winperf_measures_cross():
......@@ -92,6 +97,7 @@ def test_winperf_measures_cross():
gt = numpy.zeros((5, 5), dtype=bool)
gt[2, :] = 1.0
gt[:, 2] = 1.0 # white cross pattern
mask = None
threshold = 0.5
size = (3, 3)
stride = (1, 1)
......@@ -102,7 +108,7 @@ def test_winperf_measures_cross():
[(4, 0, 4, 1), (4, 0, 4, 1), (4, 0, 4, 1)],
[(4, 0, 4, 1), (4, 0, 4, 1), (4, 0, 4, 1)],
]
_check_window_measures(pred, gt, threshold, size, stride, expected)
_check_window_measures(pred, gt, mask, threshold, size, stride, expected)
def test_winperf_measures_cross_with_padding():
......@@ -111,6 +117,7 @@ def test_winperf_measures_cross_with_padding():
gt = numpy.zeros((5, 5), dtype=bool)
gt[2, :] = 1.0
gt[:, 2] = 1.0 # white cross pattern
mask = None
threshold = 0.5
size = (4, 4)
stride = (2, 2)
......@@ -120,7 +127,7 @@ def test_winperf_measures_cross_with_padding():
[(0, 0, 9, 7), (0, 0, 10, 6)],
[(0, 0, 10, 6), (0, 0, 11, 5)],
]
_check_window_measures(pred, gt, threshold, size, stride, expected)
_check_window_measures(pred, gt, mask, threshold, size, stride, expected)
def test_winperf_measures_cross_with_padding_2():
......@@ -132,6 +139,7 @@ def test_winperf_measures_cross_with_padding_2():
gt = numpy.zeros((5, 5), dtype=bool)
gt[2, :] = 1.0
gt[:, 2] = 1.0 # white cross pattern
mask = None
threshold = 0.5
size = (4, 4)
stride = (2, 2)
......@@ -141,17 +149,19 @@ def test_winperf_measures_cross_with_padding_2():
[(6, 0, 9, 1), (5, 0, 10, 1)],
[(5, 0, 10, 1), (4, 0, 11, 1)],
]
_check_window_measures(pred, gt, threshold, size, stride, expected)
_check_window_measures(pred, gt, mask, threshold, size, stride, expected)
def _check_performance_summary(pred, gt, threshold, size, stride, s, figure):
def _check_performance_summary(pred, gt, mask, threshold, size, stride, s, figure):
figsize = pred.shape
pred = torch.tensor(pred)
gt = torch.tensor(gt)
if mask is None:
mask = torch.ones_like(gt)
# notice _winperf_measures() was previously tested (above)
measures = _winperf_measures(pred, gt, threshold, size, stride)
measures = _winperf_measures(pred, gt, mask, threshold, size, stride)
n_actual, avg_actual, std_actual = _performance_summary(
figsize, measures, size, stride, figure
......@@ -191,6 +201,7 @@ def test_performance_summary_alltrue_accuracy():
pred = numpy.ones((4, 4), dtype=float)
gt = numpy.ones((4, 4), dtype=bool)
mask = None
threshold = 0.5
size = (2, 2)
stride = (1, 1)
......@@ -221,7 +232,7 @@ def test_performance_summary_alltrue_accuracy():
for fig in PERFORMANCE_FIGURES:
_check_performance_summary(
pred, gt, threshold, size, stride, stats, fig,
pred, gt, mask, threshold, size, stride, stats, fig,
)
......@@ -234,6 +245,7 @@ def test_performance_summary_cross():
gt = numpy.zeros((5, 5), dtype=bool)
gt[2, :] = 1.0
gt[:, 2] = 1.0 # white cross pattern
mask = None
threshold = 0.5
size = (3, 3)
stride = (1, 1)
......@@ -257,7 +269,7 @@ def test_performance_summary_cross():
for fig in PERFORMANCE_FIGURES:
_check_performance_summary(
pred, gt, threshold, size, stride, stats, fig,
pred, gt, mask, threshold, size, stride, stats, fig,
)
......@@ -267,6 +279,7 @@ def test_performance_summary_cross_with_padding():
gt = numpy.zeros((5, 5), dtype=bool)
gt[2, :] = 1.0
gt[:, 2] = 1.0 # white cross pattern
mask = None
threshold = 0.5
size = (4, 4)
stride = (2, 2)
......@@ -292,7 +305,7 @@ def test_performance_summary_cross_with_padding():
for fig in PERFORMANCE_FIGURES:
_check_performance_summary(
pred, gt, threshold, size, stride, stats, fig,
pred, gt, mask, threshold, size, stride, stats, fig,
)
......@@ -305,6 +318,7 @@ def test_performance_summary_cross_with_padding_2():
gt = numpy.zeros((5, 5), dtype=bool)
gt[2, :] = 1.0
gt[:, 2] = 1.0 # white cross pattern
mask = None
threshold = 0.5
size = (4, 4)
stride = (2, 2)
......@@ -330,5 +344,5 @@ def test_performance_summary_cross_with_padding_2():
for fig in PERFORMANCE_FIGURES:
_check_performance_summary(
pred, gt, threshold, size, stride, stats, fig,
pred, gt, mask, threshold, size, stride, stats, fig,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment