Skip to content
Snippets Groups Projects
Commit 284bc576 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[test.test_config] Test contextual data augmentation switch

parent a2301952
No related branches found
No related tags found
1 merge request!12Streamlining
...@@ -14,11 +14,13 @@ from .utils import rc_variable_set ...@@ -14,11 +14,13 @@ from .utils import rc_variable_set
# testing for the extra tools wrapping the dataset # testing for the extra tools wrapping the dataset
N = 10 N = 10
@rc_variable_set("bob.ip.binseg.drive.datadir") @rc_variable_set("bob.ip.binseg.drive.datadir")
def test_drive_default(): def test_drive_default():
from ..configs.datasets.drive.default import dataset from ..configs.datasets.drive.default import dataset
nose.tools.eq_(len(dataset["train"]), 20) nose.tools.eq_(len(dataset["train"]), 20)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 4) nose.tools.eq_(len(sample), 4)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -30,6 +32,7 @@ def test_drive_default(): ...@@ -30,6 +32,7 @@ def test_drive_default():
nose.tools.eq_(sample[3].dtype, torch.float32) nose.tools.eq_(sample[3].dtype, torch.float32)
nose.tools.eq_(len(dataset["test"]), 20) nose.tools.eq_(len(dataset["test"]), 20)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 4) nose.tools.eq_(len(sample), 4)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -41,6 +44,33 @@ def test_drive_default(): ...@@ -41,6 +44,33 @@ def test_drive_default():
nose.tools.eq_(sample[3].dtype, torch.float32) nose.tools.eq_(sample[3].dtype, torch.float32)
@stare_variable_set("bob.ip.binseg.stare.datadir")
def test_stare_augmentation_manipulation():
# some tests to check our context management for dataset augmentation works
# adequately, with one example dataset
from ..configs.datasets.stare.ah import dataset
# hack to allow testing on the CI
dataset["train"]._samples = stare_dataset.subsets("ah")["train"]
nose.tools.eq_(dataset["train"].augmented, True)
nose.tools.eq_(dataset["test"].augmented, False)
nose.tools.eq_(len(dataset["train"]._transforms.transforms),
len(dataset["test"]._transforms.transforms) + 4)
with dataset["train"].not_augmented() as d:
nose.tools.eq_(len(d._transforms.transforms), 2)
nose.tools.eq_(d.augmented, False)
nose.tools.eq_(dataset["train"].augmented, False)
nose.tools.eq_(dataset["test"].augmented, False)
nose.tools.eq_(dataset["train"].augmented, True)
nose.tools.eq_(dataset["test"].augmented, False)
nose.tools.eq_(len(dataset["train"]._transforms.transforms),
len(dataset["test"]._transforms.transforms) + 4)
@stare_variable_set("bob.ip.binseg.stare.datadir") @stare_variable_set("bob.ip.binseg.stare.datadir")
def test_stare_ah(): def test_stare_ah():
...@@ -49,6 +79,7 @@ def test_stare_ah(): ...@@ -49,6 +79,7 @@ def test_stare_ah():
dataset["train"]._samples = stare_dataset.subsets("ah")["train"] dataset["train"]._samples = stare_dataset.subsets("ah")["train"]
nose.tools.eq_(len(dataset["train"]), 10) nose.tools.eq_(len(dataset["train"]), 10)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -61,6 +92,7 @@ def test_stare_ah(): ...@@ -61,6 +92,7 @@ def test_stare_ah():
dataset["test"]._samples = stare_dataset.subsets("ah")["test"] dataset["test"]._samples = stare_dataset.subsets("ah")["test"]
nose.tools.eq_(len(dataset["test"]), 10) nose.tools.eq_(len(dataset["test"]), 10)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -78,6 +110,7 @@ def test_stare_vk(): ...@@ -78,6 +110,7 @@ def test_stare_vk():
dataset["train"]._samples = stare_dataset.subsets("vk")["train"] dataset["train"]._samples = stare_dataset.subsets("vk")["train"]
nose.tools.eq_(len(dataset["train"]), 10) nose.tools.eq_(len(dataset["train"]), 10)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -90,6 +123,7 @@ def test_stare_vk(): ...@@ -90,6 +123,7 @@ def test_stare_vk():
dataset["test"]._samples = stare_dataset.subsets("vk")["test"] dataset["test"]._samples = stare_dataset.subsets("vk")["test"]
nose.tools.eq_(len(dataset["test"]), 10) nose.tools.eq_(len(dataset["test"]), 10)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -105,6 +139,7 @@ def test_chasedb1_first_annotator(): ...@@ -105,6 +139,7 @@ def test_chasedb1_first_annotator():
from ..configs.datasets.chasedb1.first_annotator import dataset from ..configs.datasets.chasedb1.first_annotator import dataset
nose.tools.eq_(len(dataset["train"]), 8) nose.tools.eq_(len(dataset["train"]), 8)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -114,6 +149,7 @@ def test_chasedb1_first_annotator(): ...@@ -114,6 +149,7 @@ def test_chasedb1_first_annotator():
nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(sample[2].dtype, torch.float32)
nose.tools.eq_(len(dataset["test"]), 20) nose.tools.eq_(len(dataset["test"]), 20)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -129,6 +165,7 @@ def test_chasedb1_second_annotator(): ...@@ -129,6 +165,7 @@ def test_chasedb1_second_annotator():
from ..configs.datasets.chasedb1.second_annotator import dataset from ..configs.datasets.chasedb1.second_annotator import dataset
nose.tools.eq_(len(dataset["train"]), 8) nose.tools.eq_(len(dataset["train"]), 8)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -138,6 +175,7 @@ def test_chasedb1_second_annotator(): ...@@ -138,6 +175,7 @@ def test_chasedb1_second_annotator():
nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(sample[2].dtype, torch.float32)
nose.tools.eq_(len(dataset["test"]), 20) nose.tools.eq_(len(dataset["test"]), 20)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -153,6 +191,7 @@ def test_hrf_default(): ...@@ -153,6 +191,7 @@ def test_hrf_default():
from ..configs.datasets.hrf.default import dataset from ..configs.datasets.hrf.default import dataset
nose.tools.eq_(len(dataset["train"]), 15) nose.tools.eq_(len(dataset["train"]), 15)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 4) nose.tools.eq_(len(sample), 4)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -164,6 +203,7 @@ def test_hrf_default(): ...@@ -164,6 +203,7 @@ def test_hrf_default():
nose.tools.eq_(sample[3].dtype, torch.float32) nose.tools.eq_(sample[3].dtype, torch.float32)
nose.tools.eq_(len(dataset["test"]), 30) nose.tools.eq_(len(dataset["test"]), 30)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 4) nose.tools.eq_(len(sample), 4)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -181,6 +221,7 @@ def test_refuge_disc(): ...@@ -181,6 +221,7 @@ def test_refuge_disc():
from ..configs.datasets.refuge.disc import dataset from ..configs.datasets.refuge.disc import dataset
nose.tools.eq_(len(dataset["train"]), 400) nose.tools.eq_(len(dataset["train"]), 400)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -190,6 +231,7 @@ def test_refuge_disc(): ...@@ -190,6 +231,7 @@ def test_refuge_disc():
nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(sample[2].dtype, torch.float32)
nose.tools.eq_(len(dataset["validation"]), 400) nose.tools.eq_(len(dataset["validation"]), 400)
nose.tools.eq_(dataset["validation"].augmented, False)
for sample in dataset["validation"][:N]: for sample in dataset["validation"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -199,6 +241,7 @@ def test_refuge_disc(): ...@@ -199,6 +241,7 @@ def test_refuge_disc():
nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(sample[2].dtype, torch.float32)
nose.tools.eq_(len(dataset["test"]), 400) nose.tools.eq_(len(dataset["test"]), 400)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -214,6 +257,7 @@ def test_refuge_cup(): ...@@ -214,6 +257,7 @@ def test_refuge_cup():
from ..configs.datasets.refuge.cup import dataset from ..configs.datasets.refuge.cup import dataset
nose.tools.eq_(len(dataset["train"]), 400) nose.tools.eq_(len(dataset["train"]), 400)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -223,6 +267,7 @@ def test_refuge_cup(): ...@@ -223,6 +267,7 @@ def test_refuge_cup():
nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(sample[2].dtype, torch.float32)
nose.tools.eq_(len(dataset["validation"]), 400) nose.tools.eq_(len(dataset["validation"]), 400)
nose.tools.eq_(dataset["validation"].augmented, False)
for sample in dataset["validation"][:N]: for sample in dataset["validation"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -232,6 +277,7 @@ def test_refuge_cup(): ...@@ -232,6 +277,7 @@ def test_refuge_cup():
nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(sample[2].dtype, torch.float32)
nose.tools.eq_(len(dataset["test"]), 400) nose.tools.eq_(len(dataset["test"]), 400)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -247,6 +293,7 @@ def test_drishtigs1_disc_all(): ...@@ -247,6 +293,7 @@ def test_drishtigs1_disc_all():
from ..configs.datasets.drishtigs1.disc_all import dataset from ..configs.datasets.drishtigs1.disc_all import dataset
nose.tools.eq_(len(dataset["train"]), 50) nose.tools.eq_(len(dataset["train"]), 50)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -256,6 +303,7 @@ def test_drishtigs1_disc_all(): ...@@ -256,6 +303,7 @@ def test_drishtigs1_disc_all():
nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(sample[2].dtype, torch.float32)
nose.tools.eq_(len(dataset["test"]), 51) nose.tools.eq_(len(dataset["test"]), 51)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -271,6 +319,7 @@ def test_drishtigs1_cup_all(): ...@@ -271,6 +319,7 @@ def test_drishtigs1_cup_all():
from ..configs.datasets.drishtigs1.cup_all import dataset from ..configs.datasets.drishtigs1.cup_all import dataset
nose.tools.eq_(len(dataset["train"]), 50) nose.tools.eq_(len(dataset["train"]), 50)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -280,6 +329,7 @@ def test_drishtigs1_cup_all(): ...@@ -280,6 +329,7 @@ def test_drishtigs1_cup_all():
nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(sample[2].dtype, torch.float32)
nose.tools.eq_(len(dataset["test"]), 51) nose.tools.eq_(len(dataset["test"]), 51)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -295,6 +345,7 @@ def test_drionsdb_expert1(): ...@@ -295,6 +345,7 @@ def test_drionsdb_expert1():
from ..configs.datasets.drionsdb.expert1 import dataset from ..configs.datasets.drionsdb.expert1 import dataset
nose.tools.eq_(len(dataset["train"]), 60) nose.tools.eq_(len(dataset["train"]), 60)
nose.tools.eq_(dataset["train"].augmented, True)
for sample in dataset["train"][:N]: for sample in dataset["train"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -304,6 +355,7 @@ def test_drionsdb_expert1(): ...@@ -304,6 +355,7 @@ def test_drionsdb_expert1():
nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(sample[2].dtype, torch.float32)
nose.tools.eq_(len(dataset["test"]), 50) nose.tools.eq_(len(dataset["test"]), 50)
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["test"][:N]: for sample in dataset["test"][:N]:
nose.tools.eq_(len(sample), 3) nose.tools.eq_(len(sample), 3)
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -322,6 +374,8 @@ def test_drive_covd(): ...@@ -322,6 +374,8 @@ def test_drive_covd():
from ..configs.datasets.drive.covd import dataset from ..configs.datasets.drive.covd import dataset
nose.tools.eq_(len(dataset["train"]), 53) nose.tools.eq_(len(dataset["train"]), 53)
#nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["train"]: for sample in dataset["train"]:
assert 3 <= len(sample) <= 4 assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -344,6 +398,8 @@ def test_drive_ssl(): ...@@ -344,6 +398,8 @@ def test_drive_ssl():
from ..configs.datasets.drive.ssl import dataset from ..configs.datasets.drive.ssl import dataset
nose.tools.eq_(len(dataset["train"]), 53) nose.tools.eq_(len(dataset["train"]), 53)
#nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["train"]: for sample in dataset["train"]:
assert 5 <= len(sample) <= 6 assert 5 <= len(sample) <= 6
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -372,6 +428,8 @@ def test_stare_covd(): ...@@ -372,6 +428,8 @@ def test_stare_covd():
from ..configs.datasets.stare.covd import dataset from ..configs.datasets.stare.covd import dataset
nose.tools.eq_(len(dataset["train"]), 63) nose.tools.eq_(len(dataset["train"]), 63)
#nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["train"]: for sample in dataset["train"]:
assert 3 <= len(sample) <= 4 assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -393,6 +451,8 @@ def test_chasedb1_covd(): ...@@ -393,6 +451,8 @@ def test_chasedb1_covd():
from ..configs.datasets.chasedb1.covd import dataset from ..configs.datasets.chasedb1.covd import dataset
nose.tools.eq_(len(dataset["train"]), 65) nose.tools.eq_(len(dataset["train"]), 65)
#nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["train"]: for sample in dataset["train"]:
assert 3 <= len(sample) <= 4 assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -414,6 +474,8 @@ def test_hrf_covd(): ...@@ -414,6 +474,8 @@ def test_hrf_covd():
from ..configs.datasets.hrf.covd import dataset from ..configs.datasets.hrf.covd import dataset
nose.tools.eq_(len(dataset["train"]), 58) nose.tools.eq_(len(dataset["train"]), 58)
#nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["train"]: for sample in dataset["train"]:
assert 3 <= len(sample) <= 4 assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
...@@ -435,6 +497,8 @@ def test_iostar_covd(): ...@@ -435,6 +497,8 @@ def test_iostar_covd():
from ..configs.datasets.iostar.covd import dataset from ..configs.datasets.iostar.covd import dataset
nose.tools.eq_(len(dataset["train"]), 53) nose.tools.eq_(len(dataset["train"]), 53)
#nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset
nose.tools.eq_(dataset["test"].augmented, False)
for sample in dataset["train"]: for sample in dataset["train"]:
assert 3 <= len(sample) <= 4 assert 3 <= len(sample) <= 4
assert isinstance(sample[0], str) assert isinstance(sample[0], str)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment