From dc71e94f3cb41bb7099a64c98bc80459cff4df2b Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 20 Apr 2020 11:21:49 +0200
Subject: [PATCH] [data.utils] SampleList2TorchDataset: support for slicing

---
 bob/ip/binseg/data/utils.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py
index 9a5f40a1..8ca05524 100644
--- a/bob/ip/binseg/data/utils.py
+++ b/bob/ip/binseg/data/utils.py
@@ -138,13 +138,13 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
         """
         return len(self._samples)
 
-    def __getitem__(self, index):
+    def __getitem__(self, key):
         """
 
         Parameters
         ----------
 
-        index : int
+        key : int, slice
 
         Returns
         -------
@@ -154,17 +154,20 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
 
         """
 
-        item = self._samples[index]
-        data = item.data  # triggers data loading
+        if isinstance(key, slice):
+            return [self[k] for k in range(*key.indices(len(self)))]
+        else:  # we try it as an int
+            item = self._samples[key]
+            data = item.data  # triggers data loading
 
-        retval = [data["data"]]
-        if "label" in data: retval.append(data["label"])
-        if "mask" in data: retval.append(data["mask"])
+            retval = [data["data"]]
+            if "label" in data: retval.append(data["label"])
+            if "mask" in data: retval.append(data["mask"])
 
-        if self._transform:
-            retval = self._transform(*retval)
+            if self._transform:
+                retval = self._transform(*retval)
 
-        return [item.key] + retval
+            return [item.key] + retval
 
 
 class SSLDataset(torch.utils.data.Dataset):
-- 
GitLab