diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py
index 9a5f40a1a3d1f3f30b589a2eeac0bcb508558ef7..8ca055245c9694dc15ef479f008bba8a0660875c 100644
--- a/bob/ip/binseg/data/utils.py
+++ b/bob/ip/binseg/data/utils.py
@@ -138,13 +138,13 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
         """
         return len(self._samples)
 
-    def __getitem__(self, index):
+    def __getitem__(self, key):
         """
 
         Parameters
         ----------
 
-        index : int
+        key : int, slice
 
         Returns
         -------
@@ -154,17 +154,20 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
 
         """
 
-        item = self._samples[index]
-        data = item.data  # triggers data loading
+        if isinstance(key, slice):
+            return [self[k] for k in range(*key.indices(len(self)))]
+        else:  # we try it as an int
+            item = self._samples[key]
+            data = item.data  # triggers data loading
 
-        retval = [data["data"]]
-        if "label" in data: retval.append(data["label"])
-        if "mask" in data: retval.append(data["mask"])
+            retval = [data["data"]]
+            if "label" in data: retval.append(data["label"])
+            if "mask" in data: retval.append(data["mask"])
 
-        if self._transform:
-            retval = self._transform(*retval)
+            if self._transform:
+                retval = self._transform(*retval)
 
-        return [item.key] + retval
+            return [item.key] + retval
 
 
 class SSLDataset(torch.utils.data.Dataset):