diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py index 9a5f40a1a3d1f3f30b589a2eeac0bcb508558ef7..8ca055245c9694dc15ef479f008bba8a0660875c 100644 --- a/bob/ip/binseg/data/utils.py +++ b/bob/ip/binseg/data/utils.py @@ -138,13 +138,13 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): """ return len(self._samples) - def __getitem__(self, index): + def __getitem__(self, key): """ Parameters ---------- - index : int + key : int, slice Returns ------- @@ -154,17 +154,20 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): """ - item = self._samples[index] - data = item.data # triggers data loading + if isinstance(key, slice): + return [self[k] for k in range(*key.indices(len(self)))] + else: # we try it as an int + item = self._samples[key] + data = item.data # triggers data loading - retval = [data["data"]] - if "label" in data: retval.append(data["label"]) - if "mask" in data: retval.append(data["mask"]) + retval = [data["data"]] + if "label" in data: retval.append(data["label"]) + if "mask" in data: retval.append(data["mask"]) - if self._transform: - retval = self._transform(*retval) + if self._transform: + retval = self._transform(*retval) - return [item.key] + retval + return [item.key] + retval class SSLDataset(torch.utils.data.Dataset):