From 284bc5762d332195dcaf796dcaf3d5ad0dbd2887 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Mon, 20 Apr 2020 19:16:52 +0200 Subject: [PATCH] [test.test_config] Test contextual data augmentation switch --- bob/ip/binseg/test/test_config.py | 64 +++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/bob/ip/binseg/test/test_config.py b/bob/ip/binseg/test/test_config.py index b563c8f1..9b0463d6 100644 --- a/bob/ip/binseg/test/test_config.py +++ b/bob/ip/binseg/test/test_config.py @@ -14,11 +14,13 @@ from .utils import rc_variable_set # testing for the extra tools wrapping the dataset N = 10 + @rc_variable_set("bob.ip.binseg.drive.datadir") def test_drive_default(): from ..configs.datasets.drive.default import dataset nose.tools.eq_(len(dataset["train"]), 20) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 4) assert isinstance(sample[0], str) @@ -30,6 +32,7 @@ def test_drive_default(): nose.tools.eq_(sample[3].dtype, torch.float32) nose.tools.eq_(len(dataset["test"]), 20) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 4) assert isinstance(sample[0], str) @@ -41,6 +44,33 @@ def test_drive_default(): nose.tools.eq_(sample[3].dtype, torch.float32) +@stare_variable_set("bob.ip.binseg.stare.datadir") +def test_stare_augmentation_manipulation(): + + # some tests to check our context management for dataset augmentation works + # adequately, with one example dataset + + from ..configs.datasets.stare.ah import dataset + # hack to allow testing on the CI + dataset["train"]._samples = stare_dataset.subsets("ah")["train"] + + nose.tools.eq_(dataset["train"].augmented, True) + nose.tools.eq_(dataset["test"].augmented, False) + nose.tools.eq_(len(dataset["train"]._transforms.transforms), + len(dataset["test"]._transforms.transforms) + 4) + + with dataset["train"].not_augmented() as d: + nose.tools.eq_(len(d._transforms.transforms), 2) + nose.tools.eq_(d.augmented, False) + nose.tools.eq_(dataset["train"].augmented, False) + nose.tools.eq_(dataset["test"].augmented, False) + + nose.tools.eq_(dataset["train"].augmented, True) + nose.tools.eq_(dataset["test"].augmented, False) + nose.tools.eq_(len(dataset["train"]._transforms.transforms), + len(dataset["test"]._transforms.transforms) + 4) + + @stare_variable_set("bob.ip.binseg.stare.datadir") def test_stare_ah(): @@ -49,6 +79,7 @@ def test_stare_ah(): dataset["train"]._samples = stare_dataset.subsets("ah")["train"] nose.tools.eq_(len(dataset["train"]), 10) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -61,6 +92,7 @@ def test_stare_ah(): dataset["test"]._samples = stare_dataset.subsets("ah")["test"] nose.tools.eq_(len(dataset["test"]), 10) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -78,6 +110,7 @@ def test_stare_vk(): dataset["train"]._samples = stare_dataset.subsets("vk")["train"] nose.tools.eq_(len(dataset["train"]), 10) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -90,6 +123,7 @@ def test_stare_vk(): dataset["test"]._samples = stare_dataset.subsets("vk")["test"] nose.tools.eq_(len(dataset["test"]), 10) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -105,6 +139,7 @@ def test_chasedb1_first_annotator(): from ..configs.datasets.chasedb1.first_annotator import dataset nose.tools.eq_(len(dataset["train"]), 8) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -114,6 +149,7 @@ def test_chasedb1_first_annotator(): nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(len(dataset["test"]), 20) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -129,6 +165,7 @@ def test_chasedb1_second_annotator(): from ..configs.datasets.chasedb1.second_annotator import dataset nose.tools.eq_(len(dataset["train"]), 8) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -138,6 +175,7 @@ def test_chasedb1_second_annotator(): nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(len(dataset["test"]), 20) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -153,6 +191,7 @@ def test_hrf_default(): from ..configs.datasets.hrf.default import dataset nose.tools.eq_(len(dataset["train"]), 15) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 4) assert isinstance(sample[0], str) @@ -164,6 +203,7 @@ def test_hrf_default(): nose.tools.eq_(sample[3].dtype, torch.float32) nose.tools.eq_(len(dataset["test"]), 30) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 4) assert isinstance(sample[0], str) @@ -181,6 +221,7 @@ def test_refuge_disc(): from ..configs.datasets.refuge.disc import dataset nose.tools.eq_(len(dataset["train"]), 400) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -190,6 +231,7 @@ def test_refuge_disc(): nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(len(dataset["validation"]), 400) + nose.tools.eq_(dataset["validation"].augmented, False) for sample in dataset["validation"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -199,6 +241,7 @@ def test_refuge_disc(): nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(len(dataset["test"]), 400) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -214,6 +257,7 @@ def test_refuge_cup(): from ..configs.datasets.refuge.cup import dataset nose.tools.eq_(len(dataset["train"]), 400) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -223,6 +267,7 @@ def test_refuge_cup(): nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(len(dataset["validation"]), 400) + nose.tools.eq_(dataset["validation"].augmented, False) for sample in dataset["validation"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -232,6 +277,7 @@ def test_refuge_cup(): nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(len(dataset["test"]), 400) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -247,6 +293,7 @@ def test_drishtigs1_disc_all(): from ..configs.datasets.drishtigs1.disc_all import dataset nose.tools.eq_(len(dataset["train"]), 50) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -256,6 +303,7 @@ def test_drishtigs1_disc_all(): nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(len(dataset["test"]), 51) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -271,6 +319,7 @@ def test_drishtigs1_cup_all(): from ..configs.datasets.drishtigs1.cup_all import dataset nose.tools.eq_(len(dataset["train"]), 50) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -280,6 +329,7 @@ def test_drishtigs1_cup_all(): nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(len(dataset["test"]), 51) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -295,6 +345,7 @@ def test_drionsdb_expert1(): from ..configs.datasets.drionsdb.expert1 import dataset nose.tools.eq_(len(dataset["train"]), 60) + nose.tools.eq_(dataset["train"].augmented, True) for sample in dataset["train"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -304,6 +355,7 @@ def test_drionsdb_expert1(): nose.tools.eq_(sample[2].dtype, torch.float32) nose.tools.eq_(len(dataset["test"]), 50) + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["test"][:N]: nose.tools.eq_(len(sample), 3) assert isinstance(sample[0], str) @@ -322,6 +374,8 @@ def test_drive_covd(): from ..configs.datasets.drive.covd import dataset nose.tools.eq_(len(dataset["train"]), 53) + #nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["train"]: assert 3 <= len(sample) <= 4 assert isinstance(sample[0], str) @@ -344,6 +398,8 @@ def test_drive_ssl(): from ..configs.datasets.drive.ssl import dataset nose.tools.eq_(len(dataset["train"]), 53) + #nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["train"]: assert 5 <= len(sample) <= 6 assert isinstance(sample[0], str) @@ -372,6 +428,8 @@ def test_stare_covd(): from ..configs.datasets.stare.covd import dataset nose.tools.eq_(len(dataset["train"]), 63) + #nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["train"]: assert 3 <= len(sample) <= 4 assert isinstance(sample[0], str) @@ -393,6 +451,8 @@ def test_chasedb1_covd(): from ..configs.datasets.chasedb1.covd import dataset nose.tools.eq_(len(dataset["train"]), 65) + #nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["train"]: assert 3 <= len(sample) <= 4 assert isinstance(sample[0], str) @@ -414,6 +474,8 @@ def test_hrf_covd(): from ..configs.datasets.hrf.covd import dataset nose.tools.eq_(len(dataset["train"]), 58) + #nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["train"]: assert 3 <= len(sample) <= 4 assert isinstance(sample[0], str) @@ -435,6 +497,8 @@ def test_iostar_covd(): from ..configs.datasets.iostar.covd import dataset nose.tools.eq_(len(dataset["train"]), 53) + #nose.tools.eq_(dataset["train"].augmented, True) ##ConcatDataset + nose.tools.eq_(dataset["test"].augmented, False) for sample in dataset["train"]: assert 3 <= len(sample) <= 4 assert isinstance(sample[0], str) -- GitLab