Skip to content
Snippets Groups Projects
Commit 6b1409f9 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.utils] Fix SSLDataset and test it (limited)

parent aac49966
No related branches found
No related tags found
1 merge request!12Streamlining
Pipeline #38965 passed
...@@ -228,6 +228,6 @@ class SSLDataset(torch.utils.data.Dataset): ...@@ -228,6 +228,6 @@ class SSLDataset(torch.utils.data.Dataset):
retval = self.labelled[index] retval = self.labelled[index]
# gets one an unlabelled sample randomly to follow the labelled sample # gets one an unlabelled sample randomly to follow the labelled sample
unlab = self.unlabelled[torch.randint(len(self.unlabelled))] unlab = self.unlabelled[torch.randint(len(self.unlabelled), ())]
# only interested in key and data # only interested in key and data
return retval + unlab[:2] return retval + unlab[:2]
...@@ -329,6 +329,34 @@ def test_covd_drive(): ...@@ -329,6 +329,34 @@ def test_covd_drive():
nose.tools.eq_(sample[3].dtype, torch.float32) nose.tools.eq_(sample[3].dtype, torch.float32)
@rc_variable_set("bob.ip.binseg.drive.datadir")
@rc_variable_set("bob.ip.binseg.stare.datadir")
@rc_variable_set("bob.ip.binseg.chasedb1.datadir")
@rc_variable_set("bob.ip.binseg.hrf.datadir")
@rc_variable_set("bob.ip.binseg.iostar.datadir")
def test_covd_drive_ssl():
from ..configs.datasets.covd_drive_ssl import dataset
nose.tools.eq_(len(dataset), 53)
for sample in dataset:
assert 5 <= len(sample) <= 6
assert isinstance(sample[0], str)
nose.tools.eq_(sample[1].shape, (3, 544, 544)) #planes, height, width
nose.tools.eq_(sample[1].dtype, torch.float32)
nose.tools.eq_(sample[2].shape, (1, 544, 544)) #planes, height, width
nose.tools.eq_(sample[2].dtype, torch.float32)
if len(sample) == 6:
nose.tools.eq_(sample[3].shape, (1, 544, 544)) #planes, height, width
nose.tools.eq_(sample[3].dtype, torch.float32)
assert isinstance(sample[4], str)
nose.tools.eq_(sample[5].shape, (3, 544, 544)) #planes, height, width
nose.tools.eq_(sample[5].dtype, torch.float32)
else:
assert isinstance(sample[3], str)
nose.tools.eq_(sample[4].shape, (3, 544, 544)) #planes, height, width
nose.tools.eq_(sample[4].dtype, torch.float32)
@rc_variable_set("bob.ip.binseg.drive.datadir") @rc_variable_set("bob.ip.binseg.drive.datadir")
@rc_variable_set("bob.ip.binseg.chasedb1.datadir") @rc_variable_set("bob.ip.binseg.chasedb1.datadir")
@rc_variable_set("bob.ip.binseg.hrf.datadir") @rc_variable_set("bob.ip.binseg.hrf.datadir")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment