Skip to content
Snippets Groups Projects
Commit 4a219b48 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[test] Fix tests that diverged

parent 7802eab3
No related branches found
No related tags found
No related merge requests found
......@@ -54,8 +54,9 @@ def _disc_loader(sample):
retval = dict(
data=load_pil_rgb(os.path.join(_root_path, sample["data"])),
label=load_pil_rgb(os.path.join(_root_path, sample["label"])),
glaucoma=sample["glaucoma"],
)
if "glaucoma" in sample:
retval["glaucoma"] = sample["glaucoma"]
retval["label"] = retval["label"].convert("L")
retval["label"] = retval["label"].point(lambda p: p <= 150, mode="1")
return retval
......@@ -65,8 +66,9 @@ def _cup_loader(sample):
retval = dict(
data=load_pil_rgb(os.path.join(_root_path, sample["data"])),
label=load_pil_rgb(os.path.join(_root_path, sample["label"])),
glaucoma=sample["glaucoma"],
)
if "glaucoma" in sample:
retval["glaucoma"] = sample["glaucoma"]
retval["label"] = retval["label"].convert("L")
retval["label"] = retval["label"].point(lambda p: p <= 100, mode="1")
return retval
......@@ -74,7 +76,6 @@ def _cup_loader(sample):
def _loader(context, sample):
sample["glaucoma"] = False
if context["subset"] == "train":
# adds binary metadata for glaucoma/non-glaucoma patients
sample["glaucoma"] = os.path.basename(sample["label"]).startswith("g")
......
......@@ -351,11 +351,27 @@ def test_hrf():
assert s[1].max() <= 1.0
assert s[1].min() >= 0.0
def _check_subset_fullres(samples, size):
nose.tools.eq_(len(samples), size)
for s in samples:
nose.tools.eq_(len(s), 4)
assert isinstance(s[0], str)
nose.tools.eq_(s[1].shape, (3, 2336, 3296)) #planes, height, width
nose.tools.eq_(s[1].dtype, torch.float32)
nose.tools.eq_(s[2].shape, (1, 2336, 3296)) #planes, height, width
nose.tools.eq_(s[2].dtype, torch.float32)
nose.tools.eq_(s[3].shape, (1, 2336, 3296)) #planes, height, width
nose.tools.eq_(s[3].dtype, torch.float32)
assert s[1].max() <= 1.0
assert s[1].min() >= 0.0
from ..configs.datasets.hrf.default import dataset
nose.tools.eq_(len(dataset), 4)
nose.tools.eq_(len(dataset), 6)
_check_subset(dataset["__train__"], 15)
_check_subset(dataset["train"], 15)
_check_subset(dataset["test"], 30)
_check_subset_fullres(dataset["train (full resolution)"], 15)
_check_subset_fullres(dataset["test (full resolution)"], 30)
@rc_variable_set("bob.ip.binseg.drive.datadir")
......@@ -366,7 +382,7 @@ def test_hrf():
def test_hrf_mtest():
from ..configs.datasets.hrf.mtest import dataset
nose.tools.eq_(len(dataset), 10)
nose.tools.eq_(len(dataset), 12)
from ..configs.datasets.hrf.default import dataset as baseline
nose.tools.eq_(dataset["train"], baseline["train"])
......@@ -395,7 +411,7 @@ def test_hrf_mtest():
def test_hrf_covd():
from ..configs.datasets.hrf.covd import dataset
nose.tools.eq_(len(dataset), 4)
nose.tools.eq_(len(dataset), 6)
from ..configs.datasets.hrf.default import dataset as baseline
nose.tools.eq_(dataset["train"], dataset["__valid__"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment