From 4a219b48a52aa321699407ed48f527a6264dfda7 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 22 Jul 2020 15:11:15 +0200
Subject: [PATCH] [test] Fix tests that diverged

---
 bob/ip/binseg/data/refuge/__init__.py |  7 ++++---
 bob/ip/binseg/test/test_config.py     | 22 +++++++++++++++++++---
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/bob/ip/binseg/data/refuge/__init__.py b/bob/ip/binseg/data/refuge/__init__.py
index addcfca9..e86e9361 100644
--- a/bob/ip/binseg/data/refuge/__init__.py
+++ b/bob/ip/binseg/data/refuge/__init__.py
@@ -54,8 +54,9 @@ def _disc_loader(sample):
     retval = dict(
         data=load_pil_rgb(os.path.join(_root_path, sample["data"])),
         label=load_pil_rgb(os.path.join(_root_path, sample["label"])),
-        glaucoma=sample["glaucoma"],
     )
+    if "glaucoma" in sample:
+        retval["glaucoma"] = sample["glaucoma"]
     retval["label"] = retval["label"].convert("L")
     retval["label"] = retval["label"].point(lambda p: p <= 150, mode="1")
     return retval
@@ -65,8 +66,9 @@ def _cup_loader(sample):
     retval = dict(
         data=load_pil_rgb(os.path.join(_root_path, sample["data"])),
         label=load_pil_rgb(os.path.join(_root_path, sample["label"])),
-        glaucoma=sample["glaucoma"],
     )
+    if "glaucoma" in sample:
+        retval["glaucoma"] = sample["glaucoma"]
     retval["label"] = retval["label"].convert("L")
     retval["label"] = retval["label"].point(lambda p: p <= 100, mode="1")
     return retval
@@ -74,7 +76,6 @@ def _cup_loader(sample):
 
 def _loader(context, sample):
 
-    sample["glaucoma"] = False
     if context["subset"] == "train":
         # adds binary metadata for glaucoma/non-glaucoma patients
         sample["glaucoma"] = os.path.basename(sample["label"]).startswith("g")
diff --git a/bob/ip/binseg/test/test_config.py b/bob/ip/binseg/test/test_config.py
index d918118c..c5d0f376 100644
--- a/bob/ip/binseg/test/test_config.py
+++ b/bob/ip/binseg/test/test_config.py
@@ -351,11 +351,27 @@ def test_hrf():
             assert s[1].max() <= 1.0
             assert s[1].min() >= 0.0
 
+    def _check_subset_fullres(samples, size):
+        nose.tools.eq_(len(samples), size)
+        for s in samples:
+            nose.tools.eq_(len(s), 4)
+            assert isinstance(s[0], str)
+            nose.tools.eq_(s[1].shape, (3, 2336, 3296)) #planes, height, width
+            nose.tools.eq_(s[1].dtype, torch.float32)
+            nose.tools.eq_(s[2].shape, (1, 2336, 3296)) #planes, height, width
+            nose.tools.eq_(s[2].dtype, torch.float32)
+            nose.tools.eq_(s[3].shape, (1, 2336, 3296)) #planes, height, width
+            nose.tools.eq_(s[3].dtype, torch.float32)
+            assert s[1].max() <= 1.0
+            assert s[1].min() >= 0.0
+
     from ..configs.datasets.hrf.default import dataset
-    nose.tools.eq_(len(dataset), 4)
+    nose.tools.eq_(len(dataset), 6)
     _check_subset(dataset["__train__"], 15)
     _check_subset(dataset["train"], 15)
     _check_subset(dataset["test"], 30)
+    _check_subset_fullres(dataset["train (full resolution)"], 15)
+    _check_subset_fullres(dataset["test (full resolution)"], 30)
 
 
 @rc_variable_set("bob.ip.binseg.drive.datadir")
@@ -366,7 +382,7 @@ def test_hrf():
 def test_hrf_mtest():
 
     from ..configs.datasets.hrf.mtest import dataset
-    nose.tools.eq_(len(dataset), 10)
+    nose.tools.eq_(len(dataset), 12)
 
     from ..configs.datasets.hrf.default import dataset as baseline
     nose.tools.eq_(dataset["train"], baseline["train"])
@@ -395,7 +411,7 @@ def test_hrf_mtest():
 def test_hrf_covd():
 
     from ..configs.datasets.hrf.covd import dataset
-    nose.tools.eq_(len(dataset), 4)
+    nose.tools.eq_(len(dataset), 6)
 
     from ..configs.datasets.hrf.default import dataset as baseline
     nose.tools.eq_(dataset["train"], dataset["__valid__"])
-- 
GitLab