Skip to content
Snippets Groups Projects
Commit 4a8aaf38 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[script.evaluate] Fix second-annotator comparisons with a key-resemblance...

[script.evaluate] Fix second-annotator comparisons with a key-resemblance check; Add assertions in engine.evaluator to ensure proper comparisons
parent 3ea648dd
No related branches found
No related tags found
No related merge requests found
...@@ -122,7 +122,8 @@ class SampleListDataset(torch.utils.data.Dataset): ...@@ -122,7 +122,8 @@ class SampleListDataset(torch.utils.data.Dataset):
A transform object can be passed that will be applied to the image, ground A transform object can be passed that will be applied to the image, ground
truth and mask (if present). truth and mask (if present).
It supports indexing such that dataset[i] can be used to get ith sample. It supports indexing such that dataset[i] can be used to get the i-th
sample.
Attributes Attributes
...@@ -175,6 +176,18 @@ class SampleListDataset(torch.utils.data.Dataset): ...@@ -175,6 +176,18 @@ class SampleListDataset(torch.utils.data.Dataset):
return SampleListDataset(self._samples, transforms or self.transforms) return SampleListDataset(self._samples, transforms or self.transforms)
def keys(self):
"""Generator producing all keys for all samples"""
for k in self._samples:
yield k.key
def all_keys_match(self, other):
"""Compares all keys to ``other``, return ``True`` if all match
"""
return len(self) == len(other) and all(
[(ks == ko) for ks, ko in zip(self.keys(), other.keys())]
)
def __len__(self): def __len__(self):
""" """
...@@ -249,8 +262,20 @@ class SSLDataset(torch.utils.data.Dataset): ...@@ -249,8 +262,20 @@ class SSLDataset(torch.utils.data.Dataset):
""" """
def __init__(self, labelled, unlabelled): def __init__(self, labelled, unlabelled):
self.labelled = labelled self._labelled = labelled
self.unlabelled = unlabelled self._unlabelled = unlabelled
def keys(self):
"""Generator producing all keys for all samples"""
for k in self._labelled + self._unlabelled:
yield k.key
def all_keys_match(self, other):
"""Compares all keys to ``other``, return ``True`` if all match
"""
return len(self) == len(other) and all(
[(ks == ko) for ks, ko in zip(self.keys(), other.keys())]
)
def __len__(self): def __len__(self):
""" """
...@@ -263,7 +288,7 @@ class SSLDataset(torch.utils.data.Dataset): ...@@ -263,7 +288,7 @@ class SSLDataset(torch.utils.data.Dataset):
""" """
return len(self.labelled) return len(self._labelled)
def __getitem__(self, index): def __getitem__(self, index):
""" """
...@@ -281,8 +306,8 @@ class SSLDataset(torch.utils.data.Dataset): ...@@ -281,8 +306,8 @@ class SSLDataset(torch.utils.data.Dataset):
""" """
retval = self.labelled[index] retval = self._labelled[index]
# gets one an unlabelled sample randomly to follow the labelled sample # gets one an unlabelled sample randomly to follow the labelled sample
unlab = self.unlabelled[torch.randint(len(self.unlabelled), ())] unlab = self._unlabelled[torch.randint(len(self._unlabelled), ())]
# only interested in key and data # only interested in key and data
return retval + unlab[:2] return retval + unlab[:2]
...@@ -363,7 +363,8 @@ def compare_annotators(baseline, other, name, output_folder, ...@@ -363,7 +363,8 @@ def compare_annotators(baseline, other, name, output_folder,
other : py:class:`torch.utils.data.Dataset` other : py:class:`torch.utils.data.Dataset`
a second dataset, with the same samples as ``baseline``, but annotated a second dataset, with the same samples as ``baseline``, but annotated
by a different annotator than in the first dataset. by a different annotator than in the first dataset. The key values
must much between ``baseline`` and this dataset.
name : str name : str
the local name of this dataset (e.g. ``train-second-annotator``, or the local name of this dataset (e.g. ``train-second-annotator``, or
...@@ -387,6 +388,12 @@ def compare_annotators(baseline, other, name, output_folder, ...@@ -387,6 +388,12 @@ def compare_annotators(baseline, other, name, output_folder,
for baseline_sample, other_sample in tqdm( for baseline_sample, other_sample in tqdm(
list(zip(baseline, other)), desc="samples", leave=False, disable=None, list(zip(baseline, other)), desc="samples", leave=False, disable=None,
): ):
assert baseline_sample[0] == other_sample[0], f"Mismatch between " \
f"datasets for second-annotator analysis " \
f"({baseline_sample[0]} != {other_sample[0]}). This " \
f"typically occurs when the second annotator (`other`) " \
f"comes from a different dataset than the `baseline` dataset"
stem = baseline_sample[0] stem = baseline_sample[0]
image = baseline_sample[1] image = baseline_sample[1]
gt = baseline_sample[2] gt = baseline_sample[2]
......
...@@ -189,4 +189,9 @@ def evaluate( ...@@ -189,4 +189,9 @@ def evaluate(
steps=steps) steps=steps)
second = second_annotator.get(k) second = second_annotator.get(k)
if second is not None: if second is not None:
compare_annotators(v, second, k, output_folder, overlayed) if not second.all_keys_match(v):
logger.warn(f"Key mismatch between `dataset[{k}]` and " \
f"`second_annotator[{k}]` - skipping " \
f"second-annotator comparisons for {k} subset")
else:
compare_annotators(v, second, k, output_folder, overlayed)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment