Skip to content
Snippets Groups Projects
Commit aac49966 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[test.test_config] Add tests for COVD datasets

parent d1eef74f
No related branches found
No related tags found
1 merge request!12Streamlining
......@@ -307,3 +307,103 @@ def test_drionsdb_default_test():
nose.tools.eq_(sample[1].dtype, torch.float32)
nose.tools.eq_(sample[2].shape, (1, 416, 608)) #planes, height, width
nose.tools.eq_(sample[2].dtype, torch.float32)
@rc_variable_set("bob.ip.binseg.stare.datadir")
@rc_variable_set("bob.ip.binseg.chasedb1.datadir")
@rc_variable_set("bob.ip.binseg.hrf.datadir")
@rc_variable_set("bob.ip.binseg.iostar.datadir")
def test_covd_drive():
from ..configs.datasets.covd_drive import dataset
nose.tools.eq_(len(dataset), 53)
for sample in dataset:
assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str)
nose.tools.eq_(sample[1].shape, (3, 544, 544)) #planes, height, width
nose.tools.eq_(sample[1].dtype, torch.float32)
nose.tools.eq_(sample[2].shape, (1, 544, 544)) #planes, height, width
nose.tools.eq_(sample[2].dtype, torch.float32)
if len(sample) == 4:
nose.tools.eq_(sample[3].shape, (1, 544, 544)) #planes, height, width
nose.tools.eq_(sample[3].dtype, torch.float32)
@rc_variable_set("bob.ip.binseg.drive.datadir")
@rc_variable_set("bob.ip.binseg.chasedb1.datadir")
@rc_variable_set("bob.ip.binseg.hrf.datadir")
@rc_variable_set("bob.ip.binseg.iostar.datadir")
def test_covd_stare():
from ..configs.datasets.covd_stare import dataset
nose.tools.eq_(len(dataset), 63)
for sample in dataset:
assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str)
nose.tools.eq_(sample[1].shape, (3, 608, 704)) #planes, height, width
nose.tools.eq_(sample[1].dtype, torch.float32)
nose.tools.eq_(sample[2].shape, (1, 608, 704)) #planes, height, width
nose.tools.eq_(sample[2].dtype, torch.float32)
if len(sample) == 4:
nose.tools.eq_(sample[3].shape, (1, 608, 704)) #planes, height, width
nose.tools.eq_(sample[3].dtype, torch.float32)
@rc_variable_set("bob.ip.binseg.drive.datadir")
@rc_variable_set("bob.ip.binseg.stare.datadir")
@rc_variable_set("bob.ip.binseg.hrf.datadir")
@rc_variable_set("bob.ip.binseg.iostar.datadir")
def test_covd_chasedb1():
from ..configs.datasets.covd_chasedb1 import dataset
nose.tools.eq_(len(dataset), 65)
for sample in dataset:
assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str)
nose.tools.eq_(sample[1].shape, (3, 960, 960)) #planes, height, width
nose.tools.eq_(sample[1].dtype, torch.float32)
nose.tools.eq_(sample[2].shape, (1, 960, 960)) #planes, height, width
nose.tools.eq_(sample[2].dtype, torch.float32)
if len(sample) == 4:
nose.tools.eq_(sample[3].shape, (1, 960, 960)) #planes, height, width
nose.tools.eq_(sample[3].dtype, torch.float32)
@rc_variable_set("bob.ip.binseg.drive.datadir")
@rc_variable_set("bob.ip.binseg.stare.datadir")
@rc_variable_set("bob.ip.binseg.chasedb1.datadir")
@rc_variable_set("bob.ip.binseg.iostar.datadir")
def test_covd_hrf():
from ..configs.datasets.covd_hrf import dataset
nose.tools.eq_(len(dataset), 58)
for sample in dataset:
assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str)
nose.tools.eq_(sample[1].shape, (3, 1168, 1648)) #planes, height, width
nose.tools.eq_(sample[1].dtype, torch.float32)
nose.tools.eq_(sample[2].shape, (1, 1168, 1648)) #planes, height, width
nose.tools.eq_(sample[2].dtype, torch.float32)
if len(sample) == 4:
nose.tools.eq_(sample[3].shape, (1, 1168, 1648))
nose.tools.eq_(sample[3].dtype, torch.float32)
@rc_variable_set("bob.ip.binseg.drive.datadir")
@rc_variable_set("bob.ip.binseg.stare.datadir")
@rc_variable_set("bob.ip.binseg.chasedb1.datadir")
@rc_variable_set("bob.ip.binseg.hrf.datadir")
def test_covd_iostar():
from ..configs.datasets.covd_iostar_vessel import dataset
nose.tools.eq_(len(dataset), 53)
for sample in dataset:
assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str)
nose.tools.eq_(sample[1].shape, (3, 1024, 1024)) #planes, height, width
nose.tools.eq_(sample[1].dtype, torch.float32)
nose.tools.eq_(sample[2].shape, (1, 1024, 1024)) #planes, height, width
nose.tools.eq_(sample[2].dtype, torch.float32)
if len(sample) == 4:
nose.tools.eq_(sample[3].shape, (1, 1024, 1024))
nose.tools.eq_(sample[3].dtype, torch.float32)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment