From 6b1409f9273a5c3b11507b96c667892f2b06c194 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 15 Apr 2020 21:25:55 +0200
Subject: [PATCH] [data.utils] Fix SSLDataset and test it (limited)

---
 bob/ip/binseg/data/utils.py       |  2 +-
 bob/ip/binseg/test/test_config.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py
index 58f39f2f..a8046e4b 100644
--- a/bob/ip/binseg/data/utils.py
+++ b/bob/ip/binseg/data/utils.py
@@ -228,6 +228,6 @@ class SSLDataset(torch.utils.data.Dataset):
 
         retval = self.labelled[index]
         # gets one an unlabelled sample randomly to follow the labelled sample
-        unlab = self.unlabelled[torch.randint(len(self.unlabelled))]
+        unlab = self.unlabelled[torch.randint(len(self.unlabelled), ())]
         # only interested in key and data
         return retval + unlab[:2]
diff --git a/bob/ip/binseg/test/test_config.py b/bob/ip/binseg/test/test_config.py
index c40440ca..e61c38db 100644
--- a/bob/ip/binseg/test/test_config.py
+++ b/bob/ip/binseg/test/test_config.py
@@ -329,6 +329,34 @@ def test_covd_drive():
             nose.tools.eq_(sample[3].dtype, torch.float32)
 
 
+@rc_variable_set("bob.ip.binseg.drive.datadir")
+@rc_variable_set("bob.ip.binseg.stare.datadir")
+@rc_variable_set("bob.ip.binseg.chasedb1.datadir")
+@rc_variable_set("bob.ip.binseg.hrf.datadir")
+@rc_variable_set("bob.ip.binseg.iostar.datadir")
+def test_covd_drive_ssl():
+
+    from ..configs.datasets.covd_drive_ssl import dataset
+    nose.tools.eq_(len(dataset), 53)
+    for sample in dataset:
+        assert 5 <= len(sample) <= 6
+        assert isinstance(sample[0], str)
+        nose.tools.eq_(sample[1].shape, (3, 544, 544)) #planes, height, width
+        nose.tools.eq_(sample[1].dtype, torch.float32)
+        nose.tools.eq_(sample[2].shape, (1, 544, 544)) #planes, height, width
+        nose.tools.eq_(sample[2].dtype, torch.float32)
+        if len(sample) == 6:
+            nose.tools.eq_(sample[3].shape, (1, 544, 544)) #planes, height, width
+            nose.tools.eq_(sample[3].dtype, torch.float32)
+            assert isinstance(sample[4], str)
+            nose.tools.eq_(sample[5].shape, (3, 544, 544)) #planes, height, width
+            nose.tools.eq_(sample[5].dtype, torch.float32)
+        else:
+            assert isinstance(sample[3], str)
+            nose.tools.eq_(sample[4].shape, (3, 544, 544)) #planes, height, width
+            nose.tools.eq_(sample[4].dtype, torch.float32)
+
+
 @rc_variable_set("bob.ip.binseg.drive.datadir")
 @rc_variable_set("bob.ip.binseg.chasedb1.datadir")
 @rc_variable_set("bob.ip.binseg.hrf.datadir")
-- 
GitLab