From dc71e94f3cb41bb7099a64c98bc80459cff4df2b Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Mon, 20 Apr 2020 11:21:49 +0200 Subject: [PATCH] [data.utils] SampleList2TorchDataset: support for slicing --- bob/ip/binseg/data/utils.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py index 9a5f40a1..8ca05524 100644 --- a/bob/ip/binseg/data/utils.py +++ b/bob/ip/binseg/data/utils.py @@ -138,13 +138,13 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): """ return len(self._samples) - def __getitem__(self, index): + def __getitem__(self, key): """ Parameters ---------- - index : int + key : int, slice Returns ------- @@ -154,17 +154,20 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): """ - item = self._samples[index] - data = item.data # triggers data loading + if isinstance(key, slice): + return [self[k] for k in range(*key.indices(len(self)))] + else: # we try it as an int + item = self._samples[key] + data = item.data # triggers data loading - retval = [data["data"]] - if "label" in data: retval.append(data["label"]) - if "mask" in data: retval.append(data["mask"]) + retval = [data["data"]] + if "label" in data: retval.append(data["label"]) + if "mask" in data: retval.append(data["mask"]) - if self._transform: - retval = self._transform(*retval) + if self._transform: + retval = self._transform(*retval) - return [item.key] + retval + return [item.key] + retval class SSLDataset(torch.utils.data.Dataset): -- GitLab