From 4a8aaf386ac71174a24c2f2cf4eb485bd9b13fd5 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 19 May 2020 13:25:18 +0200
Subject: [PATCH] [script.evaluate] Fix second-annotator comparisons with a
 key-resemblance check; Add assertions in engine.evaluator to ensure proper
 comparisons

---
 bob/ip/binseg/data/utils.py       | 37 ++++++++++++++++++++++++++-----
 bob/ip/binseg/engine/evaluator.py |  9 +++++++-
 bob/ip/binseg/script/evaluate.py  |  7 +++++-
 3 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py
index 2d1439a6..6bdd727b 100644
--- a/bob/ip/binseg/data/utils.py
+++ b/bob/ip/binseg/data/utils.py
@@ -122,7 +122,8 @@ class SampleListDataset(torch.utils.data.Dataset):
     A transform object can be passed that will be applied to the image,
     ground truth and mask (if present).
 
-    It supports indexing such that dataset[i] can be used to get ith sample.
+    It supports indexing such that dataset[i] can be used to get the i-th
+    sample.
 
 
     Attributes
@@ -175,6 +176,18 @@ class SampleListDataset(torch.utils.data.Dataset):
 
         return SampleListDataset(self._samples, transforms or self.transforms)
 
+    def keys(self):
+        """Generator producing all keys for all samples"""
+        for k in self._samples:
+            yield k.key
+
+    def all_keys_match(self, other):
+        """Compares all keys to ``other``, returning ``True`` if all match
+        """
+        return len(self) == len(other) and all(
+            [(ks == ko) for ks, ko in zip(self.keys(), other.keys())]
+        )
+
     def __len__(self):
         """
 
@@ -249,8 +262,20 @@ class SSLDataset(torch.utils.data.Dataset):
     """
 
     def __init__(self, labelled, unlabelled):
-        self.labelled = labelled
-        self.unlabelled = unlabelled
+        self._labelled = labelled
+        self._unlabelled = unlabelled
+
+    def keys(self):
+        """Generator producing all keys for all samples"""
+        for k in self._labelled + self._unlabelled:
+            yield k.key
+
+    def all_keys_match(self, other):
+        """Compares all keys to ``other``, returning ``True`` if all match
+        """
+        return len(self) == len(other) and all(
+            [(ks == ko) for ks, ko in zip(self.keys(), other.keys())]
+        )
 
     def __len__(self):
         """
@@ -263,7 +288,7 @@ class SSLDataset(torch.utils.data.Dataset):
 
         """
 
-        return len(self.labelled)
+        return len(self._labelled)
 
     def __getitem__(self, index):
         """
@@ -281,8 +306,8 @@ class SSLDataset(torch.utils.data.Dataset):
 
         """
 
-        retval = self.labelled[index]
+        retval = self._labelled[index]
         # gets one an unlabelled sample randomly to follow the labelled sample
-        unlab = self.unlabelled[torch.randint(len(self.unlabelled), ())]
+        unlab = self._unlabelled[torch.randint(len(self._unlabelled), ())]
         # only interested in key and data
         return retval + unlab[:2]
diff --git a/bob/ip/binseg/engine/evaluator.py b/bob/ip/binseg/engine/evaluator.py
index eab56b66..66d5d40b 100644
--- a/bob/ip/binseg/engine/evaluator.py
+++ b/bob/ip/binseg/engine/evaluator.py
@@ -363,7 +363,8 @@ def compare_annotators(baseline, other, name, output_folder,
 
     other : py:class:`torch.utils.data.Dataset`
         a second dataset, with the same samples as ``baseline``, but annotated
-        by a different annotator than in the first dataset.
+        by a different annotator than in the first dataset.  The key values
+        must match between ``baseline`` and this dataset.
 
     name : str
         the local name of this dataset (e.g. ``train-second-annotator``, or
@@ -387,6 +388,12 @@
     for baseline_sample, other_sample in tqdm(
         list(zip(baseline, other)), desc="samples", leave=False, disable=None,
     ):
+        assert baseline_sample[0] == other_sample[0], f"Mismatch between " \
+            f"datasets for second-annotator analysis " \
+            f"({baseline_sample[0]} != {other_sample[0]}). This " \
+            f"typically occurs when the second annotator (`other`) " \
+            f"comes from a different dataset than the `baseline` dataset"
+
         stem = baseline_sample[0]
         image = baseline_sample[1]
         gt = baseline_sample[2]
diff --git a/bob/ip/binseg/script/evaluate.py b/bob/ip/binseg/script/evaluate.py
index 2e558671..be3d8bf2 100644
--- a/bob/ip/binseg/script/evaluate.py
+++ b/bob/ip/binseg/script/evaluate.py
@@ -189,4 +189,9 @@ def evaluate(
                 steps=steps)
         second = second_annotator.get(k)
         if second is not None:
-            compare_annotators(v, second, k, output_folder, overlayed)
+            if not second.all_keys_match(v):
+                logger.warn(f"Key mismatch between `dataset[{k}]` and " \
+                    f"`second_annotator[{k}]` - skipping " \
+                    f"second-annotator comparisons for {k} subset")
+            else:
+                compare_annotators(v, second, k, output_folder, overlayed)
-- 
GitLab
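
For reference, here is a minimal, self-contained sketch of the key-matching guard this patch introduces.  The Sample and TinyDataset names below are hypothetical stand-ins used only for illustration (they are not the bob.ip.binseg classes); they merely mirror the .key attribute and the keys()/all_keys_match() contract added to SampleListDataset and SSLDataset.

# Minimal standalone sketch of the key-matching guard (illustration only;
# Sample and TinyDataset are hypothetical stand-ins, not library classes).
import collections

Sample = collections.namedtuple("Sample", ["key", "data"])


class TinyDataset:
    def __init__(self, samples):
        self._samples = samples

    def __len__(self):
        return len(self._samples)

    def keys(self):
        # generator producing the key of every sample, in order
        for s in self._samples:
            yield s.key

    def all_keys_match(self, other):
        # same logic as the patch: equal length and pairwise-equal keys
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )


baseline = TinyDataset([Sample("img-01", ...), Sample("img-02", ...)])
second = TinyDataset([Sample("img-01", ...), Sample("img-02", ...)])
other_db = TinyDataset([Sample("xyz-98", ...), Sample("xyz-99", ...)])

assert baseline.all_keys_match(second)        # same keys: comparison is safe
assert not baseline.all_keys_match(other_db)  # key mismatch: would be skipped

As in the patch, a second-annotator comparison is only attempted when both datasets have the same length and pairwise-identical keys; otherwise the script logs a warning and skips the subset, and the engine-level assertion catches any remaining per-sample mismatch.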