From 4a219b48a52aa321699407ed48f527a6264dfda7 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 22 Jul 2020 15:11:15 +0200 Subject: [PATCH] [test] Fix tests that diverged --- bob/ip/binseg/data/refuge/__init__.py | 7 ++++--- bob/ip/binseg/test/test_config.py | 22 +++++++++++++++++++--- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/bob/ip/binseg/data/refuge/__init__.py b/bob/ip/binseg/data/refuge/__init__.py index addcfca9..e86e9361 100644 --- a/bob/ip/binseg/data/refuge/__init__.py +++ b/bob/ip/binseg/data/refuge/__init__.py @@ -54,8 +54,9 @@ def _disc_loader(sample): retval = dict( data=load_pil_rgb(os.path.join(_root_path, sample["data"])), label=load_pil_rgb(os.path.join(_root_path, sample["label"])), - glaucoma=sample["glaucoma"], ) + if "glaucoma" in sample: + retval["glaucoma"] = sample["glaucoma"] retval["label"] = retval["label"].convert("L") retval["label"] = retval["label"].point(lambda p: p <= 150, mode="1") return retval @@ -65,8 +66,9 @@ def _cup_loader(sample): retval = dict( data=load_pil_rgb(os.path.join(_root_path, sample["data"])), label=load_pil_rgb(os.path.join(_root_path, sample["label"])), - glaucoma=sample["glaucoma"], ) + if "glaucoma" in sample: + retval["glaucoma"] = sample["glaucoma"] retval["label"] = retval["label"].convert("L") retval["label"] = retval["label"].point(lambda p: p <= 100, mode="1") return retval @@ -74,7 +76,6 @@ def _cup_loader(sample): def _loader(context, sample): - sample["glaucoma"] = False if context["subset"] == "train": # adds binary metadata for glaucoma/non-glaucoma patients sample["glaucoma"] = os.path.basename(sample["label"]).startswith("g") diff --git a/bob/ip/binseg/test/test_config.py b/bob/ip/binseg/test/test_config.py index d918118c..c5d0f376 100644 --- a/bob/ip/binseg/test/test_config.py +++ b/bob/ip/binseg/test/test_config.py @@ -351,11 +351,27 @@ def test_hrf(): assert s[1].max() <= 1.0 assert s[1].min() >= 0.0 + def _check_subset_fullres(samples, size): + nose.tools.eq_(len(samples), size) + for s in samples: + nose.tools.eq_(len(s), 4) + assert isinstance(s[0], str) + nose.tools.eq_(s[1].shape, (3, 2336, 3296)) #planes, height, width + nose.tools.eq_(s[1].dtype, torch.float32) + nose.tools.eq_(s[2].shape, (1, 2336, 3296)) #planes, height, width + nose.tools.eq_(s[2].dtype, torch.float32) + nose.tools.eq_(s[3].shape, (1, 2336, 3296)) #planes, height, width + nose.tools.eq_(s[3].dtype, torch.float32) + assert s[1].max() <= 1.0 + assert s[1].min() >= 0.0 + from ..configs.datasets.hrf.default import dataset - nose.tools.eq_(len(dataset), 4) + nose.tools.eq_(len(dataset), 6) _check_subset(dataset["__train__"], 15) _check_subset(dataset["train"], 15) _check_subset(dataset["test"], 30) + _check_subset_fullres(dataset["train (full resolution)"], 15) + _check_subset_fullres(dataset["test (full resolution)"], 30) @rc_variable_set("bob.ip.binseg.drive.datadir") @@ -366,7 +382,7 @@ def test_hrf(): def test_hrf_mtest(): from ..configs.datasets.hrf.mtest import dataset - nose.tools.eq_(len(dataset), 10) + nose.tools.eq_(len(dataset), 12) from ..configs.datasets.hrf.default import dataset as baseline nose.tools.eq_(dataset["train"], baseline["train"]) @@ -395,7 +411,7 @@ def test_hrf_mtest(): def test_hrf_covd(): from ..configs.datasets.hrf.covd import dataset - nose.tools.eq_(len(dataset), 4) + nose.tools.eq_(len(dataset), 6) from ..configs.datasets.hrf.default import dataset as baseline nose.tools.eq_(dataset["train"], dataset["__valid__"]) -- GitLab