Skip to content
Snippets Groups Projects
Commit dc71e94f authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.utils] SampleList2TorchDataset: support for slicing

parent 77a8a6ad
No related branches found
No related tags found
1 merge request!12Streamlining
......@@ -138,13 +138,13 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
"""
return len(self._samples)
def __getitem__(self, index):
def __getitem__(self, key):
"""
Parameters
----------
index : int
key : int, slice
Returns
-------
......@@ -154,17 +154,20 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
"""
item = self._samples[index]
data = item.data # triggers data loading
if isinstance(key, slice):
return [self[k] for k in range(*key.indices(len(self)))]
else: # we try it as an int
item = self._samples[key]
data = item.data # triggers data loading
retval = [data["data"]]
if "label" in data: retval.append(data["label"])
if "mask" in data: retval.append(data["mask"])
retval = [data["data"]]
if "label" in data: retval.append(data["label"])
if "mask" in data: retval.append(data["mask"])
if self._transform:
retval = self._transform(*retval)
if self._transform:
retval = self._transform(*retval)
return [item.key] + retval
return [item.key] + retval
class SSLDataset(torch.utils.data.Dataset):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment