diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py index 2d1439a63e774d7270e76ca9b1fd3f030d6aec72..6bdd727b7c63bdff0bf99383ca4d394d2ba87876 100644 --- a/bob/ip/binseg/data/utils.py +++ b/bob/ip/binseg/data/utils.py @@ -122,7 +122,8 @@ class SampleListDataset(torch.utils.data.Dataset): A transform object can be passed that will be applied to the image, ground truth and mask (if present). - It supports indexing such that dataset[i] can be used to get ith sample. + It supports indexing such that dataset[i] can be used to get the i-th + sample. Attributes @@ -175,6 +176,18 @@ class SampleListDataset(torch.utils.data.Dataset): return SampleListDataset(self._samples, transforms or self.transforms) + def keys(self): + """Generator producing all keys for all samples""" + for k in self._samples: + yield k.key + + def all_keys_match(self, other): + """Compares all keys to ``other``, return ``True`` if all match + """ + return len(self) == len(other) and all( + [(ks == ko) for ks, ko in zip(self.keys(), other.keys())] + ) + def __len__(self): """ @@ -249,8 +262,20 @@ class SSLDataset(torch.utils.data.Dataset): """ def __init__(self, labelled, unlabelled): - self.labelled = labelled - self.unlabelled = unlabelled + self._labelled = labelled + self._unlabelled = unlabelled + + def keys(self): + """Generator producing all keys for all samples""" + for k in self._labelled + self._unlabelled: + yield k.key + + def all_keys_match(self, other): + """Compares all keys to ``other``, return ``True`` if all match + """ + return len(self) == len(other) and all( + [(ks == ko) for ks, ko in zip(self.keys(), other.keys())] + ) def __len__(self): """ @@ -263,7 +288,7 @@ class SSLDataset(torch.utils.data.Dataset): """ - return len(self.labelled) + return len(self._labelled) def __getitem__(self, index): """ @@ -281,8 +306,8 @@ class SSLDataset(torch.utils.data.Dataset): """ - retval = self.labelled[index] + retval = self._labelled[index] # gets one an unlabelled sample randomly to follow the labelled sample - unlab = self.unlabelled[torch.randint(len(self.unlabelled), ())] + unlab = self._unlabelled[torch.randint(len(self._unlabelled), ())] # only interested in key and data return retval + unlab[:2] diff --git a/bob/ip/binseg/engine/evaluator.py b/bob/ip/binseg/engine/evaluator.py index eab56b66290bf8c5a92ce01cdb10595de9d0e47e..66d5d40bf78a022b9c507e358f0b7fbf11016812 100644 --- a/bob/ip/binseg/engine/evaluator.py +++ b/bob/ip/binseg/engine/evaluator.py @@ -363,7 +363,8 @@ def compare_annotators(baseline, other, name, output_folder, other : py:class:`torch.utils.data.Dataset` a second dataset, with the same samples as ``baseline``, but annotated - by a different annotator than in the first dataset. + by a different annotator than in the first dataset. The key values + must much between ``baseline`` and this dataset. name : str the local name of this dataset (e.g. ``train-second-annotator``, or @@ -387,6 +388,12 @@ def compare_annotators(baseline, other, name, output_folder, for baseline_sample, other_sample in tqdm( list(zip(baseline, other)), desc="samples", leave=False, disable=None, ): + assert baseline_sample[0] == other_sample[0], f"Mismatch between " \ + f"datasets for second-annotator analysis " \ + f"({baseline_sample[0]} != {other_sample[0]}). This " \ + f"typically occurs when the second annotator (`other`) " \ + f"comes from a different dataset than the `baseline` dataset" + stem = baseline_sample[0] image = baseline_sample[1] gt = baseline_sample[2] diff --git a/bob/ip/binseg/script/evaluate.py b/bob/ip/binseg/script/evaluate.py index 2e558671d1d7bd4919f37f3a63d160387c1988f2..be3d8bf2a2edc7be425707b41e6fcdc370fdeea9 100644 --- a/bob/ip/binseg/script/evaluate.py +++ b/bob/ip/binseg/script/evaluate.py @@ -189,4 +189,9 @@ def evaluate( steps=steps) second = second_annotator.get(k) if second is not None: - compare_annotators(v, second, k, output_folder, overlayed) + if not second.all_keys_match(v): + logger.warn(f"Key mismatch between `dataset[{k}]` and " \ + f"`second_annotator[{k}]` - skipping " \ + f"second-annotator comparisons for {k} subset") + else: + compare_annotators(v, second, k, output_folder, overlayed)