Skip to content
Snippets Groups Projects
Commit dc71e94f authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.utils] SampleList2TorchDataset: support for slicing

parent 77a8a6ad
No related branches found
No related tags found
1 merge request!12Streamlining
...@@ -138,13 +138,13 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): ...@@ -138,13 +138,13 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
""" """
return len(self._samples) return len(self._samples)
def __getitem__(self, index): def __getitem__(self, key):
""" """
Parameters Parameters
---------- ----------
index : int key : int, slice
Returns Returns
------- -------
...@@ -154,17 +154,20 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): ...@@ -154,17 +154,20 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
""" """
item = self._samples[index] if isinstance(key, slice):
data = item.data # triggers data loading return [self[k] for k in range(*key.indices(len(self)))]
else: # we try it as an int
item = self._samples[key]
data = item.data # triggers data loading
retval = [data["data"]] retval = [data["data"]]
if "label" in data: retval.append(data["label"]) if "label" in data: retval.append(data["label"])
if "mask" in data: retval.append(data["mask"]) if "mask" in data: retval.append(data["mask"])
if self._transform: if self._transform:
retval = self._transform(*retval) retval = self._transform(*retval)
return [item.key] + retval return [item.key] + retval
class SSLDataset(torch.utils.data.Dataset): class SSLDataset(torch.utils.data.Dataset):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment