diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py index 58f39f2fce678b78471b1ea013866747b28b0358..a8046e4bc2b0ebf466a3e0cdefdc0d93b2a5810f 100644 --- a/bob/ip/binseg/data/utils.py +++ b/bob/ip/binseg/data/utils.py @@ -228,6 +228,6 @@ class SSLDataset(torch.utils.data.Dataset): retval = self.labelled[index] # gets one an unlabelled sample randomly to follow the labelled sample - unlab = self.unlabelled[torch.randint(len(self.unlabelled))] + unlab = self.unlabelled[torch.randint(len(self.unlabelled), ())] # only interested in key and data return retval + unlab[:2] diff --git a/bob/ip/binseg/test/test_config.py b/bob/ip/binseg/test/test_config.py index c40440ca50d988c2815c7945cb60e226b8f5823d..e61c38dbe3534e345befc6e279ce81950043d4c2 100644 --- a/bob/ip/binseg/test/test_config.py +++ b/bob/ip/binseg/test/test_config.py @@ -329,6 +329,34 @@ def test_covd_drive(): nose.tools.eq_(sample[3].dtype, torch.float32) +@rc_variable_set("bob.ip.binseg.drive.datadir") +@rc_variable_set("bob.ip.binseg.stare.datadir") +@rc_variable_set("bob.ip.binseg.chasedb1.datadir") +@rc_variable_set("bob.ip.binseg.hrf.datadir") +@rc_variable_set("bob.ip.binseg.iostar.datadir") +def test_covd_drive_ssl(): + + from ..configs.datasets.covd_drive_ssl import dataset + nose.tools.eq_(len(dataset), 53) + for sample in dataset: + assert 5 <= len(sample) <= 6 + assert isinstance(sample[0], str) + nose.tools.eq_(sample[1].shape, (3, 544, 544)) #planes, height, width + nose.tools.eq_(sample[1].dtype, torch.float32) + nose.tools.eq_(sample[2].shape, (1, 544, 544)) #planes, height, width + nose.tools.eq_(sample[2].dtype, torch.float32) + if len(sample) == 6: + nose.tools.eq_(sample[3].shape, (1, 544, 544)) #planes, height, width + nose.tools.eq_(sample[3].dtype, torch.float32) + assert isinstance(sample[4], str) + nose.tools.eq_(sample[5].shape, (3, 544, 544)) #planes, height, width + nose.tools.eq_(sample[5].dtype, torch.float32) + else: + assert isinstance(sample[3], str) + nose.tools.eq_(sample[4].shape, (3, 544, 544)) #planes, height, width + nose.tools.eq_(sample[4].dtype, torch.float32) + + @rc_variable_set("bob.ip.binseg.drive.datadir") @rc_variable_set("bob.ip.binseg.chasedb1.datadir") @rc_variable_set("bob.ip.binseg.hrf.datadir")