Skip to content
Snippets Groups Projects
Commit 4a8aaf38 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[script.evaluate] Fix second-annotator comparisons with a key-resemblance...

[script.evaluate] Fix second-annotator comparisons with a key-resemblance check; Add assertions in engine.evaluator to ensure proper comparisons
parent 3ea648dd
No related branches found
No related tags found
No related merge requests found
......@@ -122,7 +122,8 @@ class SampleListDataset(torch.utils.data.Dataset):
A transform object can be passed that will be applied to the image, ground
truth and mask (if present).
It supports indexing such that dataset[i] can be used to get ith sample.
It supports indexing such that dataset[i] can be used to get the i-th
sample.
Attributes
......@@ -175,6 +176,18 @@ class SampleListDataset(torch.utils.data.Dataset):
return SampleListDataset(self._samples, transforms or self.transforms)
def keys(self):
    """Generator producing all keys for all samples

    Walks ``self._samples`` in storage order and yields the ``key``
    attribute of each entry, one at a time.

    Yields
    ------
    key
        the ``key`` attribute of the next sample in ``self._samples``
    """
    for sample in self._samples:
        yield sample.key
def all_keys_match(self, other):
    """Compares all keys to ``other``, return ``True`` if all match

    Keys are compared pairwise, in iteration order, so two datasets with
    the same keys in a different order do **not** match.

    Parameters
    ----------
    other : object
        anything supporting ``len()`` and exposing a ``keys()`` method
        that iterates over its sample keys

    Returns
    -------
    bool
        ``True`` if both objects have the same length and every pair of
        keys at the same position compares equal, ``False`` otherwise
    """
    # Different sizes can never match; checking this first also guards
    # against ``zip`` silently truncating to the shorter sequence.
    if len(self) != len(other):
        return False
    # A generator (rather than a list) lets ``all`` stop at the first
    # mismatching pair instead of comparing every key up front.
    return all(ks == ko for ks, ko in zip(self.keys(), other.keys()))
def __len__(self):
"""
......@@ -249,8 +262,20 @@ class SSLDataset(torch.utils.data.Dataset):
"""
def __init__(self, labelled, unlabelled):
self.labelled = labelled
self.unlabelled = unlabelled
self._labelled = labelled
self._unlabelled = unlabelled
def keys(self):
    """Generator producing all keys for all samples

    Yields the ``key`` attribute of every entry in ``self._labelled``
    first, followed by every entry in ``self._unlabelled``.

    Yields
    ------
    key
        the ``key`` attribute of the next sample
    """
    from itertools import chain

    # Chaining iterates both collections back-to-back without building a
    # concatenated copy, and does not require them to support ``+``.
    for sample in chain(self._labelled, self._unlabelled):
        yield sample.key
def all_keys_match(self, other):
    """Compares all keys to ``other``, return ``True`` if all match

    Keys are compared pairwise, in iteration order, so two datasets with
    the same keys in a different order do **not** match.

    Parameters
    ----------
    other : object
        anything supporting ``len()`` and exposing a ``keys()`` method
        that iterates over its sample keys

    Returns
    -------
    bool
        ``True`` if both objects have the same length and every pair of
        keys at the same position compares equal, ``False`` otherwise
    """
    # Different sizes can never match; bail out before touching any key.
    if len(self) != len(other):
        return False
    # A generator (rather than a list) lets ``all`` stop at the first
    # mismatching pair instead of comparing every key up front.
    return all(ks == ko for ks, ko in zip(self.keys(), other.keys()))
def __len__(self):
"""
......@@ -263,7 +288,7 @@ class SSLDataset(torch.utils.data.Dataset):
"""
return len(self.labelled)
return len(self._labelled)
def __getitem__(self, index):
"""
......@@ -281,8 +306,8 @@ class SSLDataset(torch.utils.data.Dataset):
"""
retval = self.labelled[index]
retval = self._labelled[index]
# gets one unlabelled sample at random to follow the labelled sample
unlab = self.unlabelled[torch.randint(len(self.unlabelled), ())]
unlab = self._unlabelled[torch.randint(len(self._unlabelled), ())]
# only interested in key and data
return retval + unlab[:2]
......@@ -363,7 +363,8 @@ def compare_annotators(baseline, other, name, output_folder,
other : py:class:`torch.utils.data.Dataset`
a second dataset, with the same samples as ``baseline``, but annotated
by a different annotator than in the first dataset.
by a different annotator than in the first dataset. The key values
must match between ``baseline`` and this dataset.
name : str
the local name of this dataset (e.g. ``train-second-annotator``, or
......@@ -387,6 +388,12 @@ def compare_annotators(baseline, other, name, output_folder,
for baseline_sample, other_sample in tqdm(
list(zip(baseline, other)), desc="samples", leave=False, disable=None,
):
assert baseline_sample[0] == other_sample[0], f"Mismatch between " \
f"datasets for second-annotator analysis " \
f"({baseline_sample[0]} != {other_sample[0]}). This " \
f"typically occurs when the second annotator (`other`) " \
f"comes from a different dataset than the `baseline` dataset"
stem = baseline_sample[0]
image = baseline_sample[1]
gt = baseline_sample[2]
......
......@@ -189,4 +189,9 @@ def evaluate(
steps=steps)
second = second_annotator.get(k)
if second is not None:
compare_annotators(v, second, k, output_folder, overlayed)
if not second.all_keys_match(v):
logger.warn(f"Key mismatch between `dataset[{k}]` and " \
f"`second_annotator[{k}]` - skipping " \
f"second-annotator comparisons for {k} subset")
else:
compare_annotators(v, second, k, output_folder, overlayed)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment