From 284bc5762d332195dcaf796dcaf3d5ad0dbd2887 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 20 Apr 2020 19:16:52 +0200
Subject: [PATCH] [test.test_config] Test contextual data augmentation switch

---
 bob/ip/binseg/test/test_config.py | 64 +++++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/bob/ip/binseg/test/test_config.py b/bob/ip/binseg/test/test_config.py
index b563c8f1..9b0463d6 100644
--- a/bob/ip/binseg/test/test_config.py
+++ b/bob/ip/binseg/test/test_config.py
@@ -14,11 +14,13 @@ from .utils import rc_variable_set
 # testing for the extra tools wrapping the dataset
 N = 10
 
+
 @rc_variable_set("bob.ip.binseg.drive.datadir")
 def test_drive_default():
 
     from ..configs.datasets.drive.default import dataset
     nose.tools.eq_(len(dataset["train"]), 20)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 4)
         assert isinstance(sample[0], str)
@@ -30,6 +32,7 @@ def test_drive_default():
         nose.tools.eq_(sample[3].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["test"]), 20)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 4)
         assert isinstance(sample[0], str)
@@ -41,6 +44,33 @@ def test_drive_default():
         nose.tools.eq_(sample[3].dtype, torch.float32)
 
 
+@stare_variable_set("bob.ip.binseg.stare.datadir")
+def test_stare_augmentation_manipulation():
+
+    # checks that the context manager used to switch dataset augmentation on
+    # and off works adequately, with one example dataset
+
+    from ..configs.datasets.stare.ah import dataset
+    # hack to allow testing on the CI
+    dataset["train"]._samples = stare_dataset.subsets("ah")["train"]
+
+    nose.tools.eq_(dataset["train"].augmented, True)
+    nose.tools.eq_(dataset["test"].augmented, False)
+    nose.tools.eq_(len(dataset["train"]._transforms.transforms),
+            len(dataset["test"]._transforms.transforms) + 4)
+
+    with dataset["train"].not_augmented() as d:
+        nose.tools.eq_(len(d._transforms.transforms), 2)
+        nose.tools.eq_(d.augmented, False)
+        nose.tools.eq_(dataset["train"].augmented, False)
+        nose.tools.eq_(dataset["test"].augmented, False)
+
+    nose.tools.eq_(dataset["train"].augmented, True)
+    nose.tools.eq_(dataset["test"].augmented, False)
+    nose.tools.eq_(len(dataset["train"]._transforms.transforms),
+            len(dataset["test"]._transforms.transforms) + 4)
+
+
 @stare_variable_set("bob.ip.binseg.stare.datadir")
 def test_stare_ah():
 
@@ -49,6 +79,7 @@ def test_stare_ah():
     dataset["train"]._samples = stare_dataset.subsets("ah")["train"]
 
     nose.tools.eq_(len(dataset["train"]), 10)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -61,6 +92,7 @@ def test_stare_ah():
     dataset["test"]._samples = stare_dataset.subsets("ah")["test"]
 
     nose.tools.eq_(len(dataset["test"]), 10)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -78,6 +110,7 @@ def test_stare_vk():
     dataset["train"]._samples = stare_dataset.subsets("vk")["train"]
 
     nose.tools.eq_(len(dataset["train"]), 10)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -90,6 +123,7 @@ def test_stare_vk():
     dataset["test"]._samples = stare_dataset.subsets("vk")["test"]
 
     nose.tools.eq_(len(dataset["test"]), 10)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -105,6 +139,7 @@ def test_chasedb1_first_annotator():
     from ..configs.datasets.chasedb1.first_annotator import dataset
 
     nose.tools.eq_(len(dataset["train"]), 8)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -114,6 +149,7 @@ def test_chasedb1_first_annotator():
         nose.tools.eq_(sample[2].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["test"]), 20)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -129,6 +165,7 @@ def test_chasedb1_second_annotator():
     from ..configs.datasets.chasedb1.second_annotator import dataset
 
     nose.tools.eq_(len(dataset["train"]), 8)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -138,6 +175,7 @@ def test_chasedb1_second_annotator():
         nose.tools.eq_(sample[2].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["test"]), 20)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -153,6 +191,7 @@ def test_hrf_default():
     from ..configs.datasets.hrf.default import dataset
 
     nose.tools.eq_(len(dataset["train"]), 15)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 4)
         assert isinstance(sample[0], str)
@@ -164,6 +203,7 @@ def test_hrf_default():
         nose.tools.eq_(sample[3].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["test"]), 30)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 4)
         assert isinstance(sample[0], str)
@@ -181,6 +221,7 @@ def test_refuge_disc():
     from ..configs.datasets.refuge.disc import dataset
 
     nose.tools.eq_(len(dataset["train"]), 400)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -190,6 +231,7 @@ def test_refuge_disc():
         nose.tools.eq_(sample[2].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["validation"]), 400)
+    nose.tools.eq_(dataset["validation"].augmented, False)
     for sample in dataset["validation"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -199,6 +241,7 @@ def test_refuge_disc():
         nose.tools.eq_(sample[2].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["test"]), 400)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -214,6 +257,7 @@ def test_refuge_cup():
     from ..configs.datasets.refuge.cup import dataset
 
     nose.tools.eq_(len(dataset["train"]), 400)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -223,6 +267,7 @@ def test_refuge_cup():
         nose.tools.eq_(sample[2].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["validation"]), 400)
+    nose.tools.eq_(dataset["validation"].augmented, False)
     for sample in dataset["validation"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -232,6 +277,7 @@ def test_refuge_cup():
         nose.tools.eq_(sample[2].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["test"]), 400)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -247,6 +293,7 @@ def test_drishtigs1_disc_all():
     from ..configs.datasets.drishtigs1.disc_all import dataset
 
     nose.tools.eq_(len(dataset["train"]), 50)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -256,6 +303,7 @@ def test_drishtigs1_disc_all():
         nose.tools.eq_(sample[2].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["test"]), 51)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -271,6 +319,7 @@ def test_drishtigs1_cup_all():
     from ..configs.datasets.drishtigs1.cup_all import dataset
 
     nose.tools.eq_(len(dataset["train"]), 50)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -280,6 +329,7 @@ def test_drishtigs1_cup_all():
         nose.tools.eq_(sample[2].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["test"]), 51)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -295,6 +345,7 @@ def test_drionsdb_expert1():
     from ..configs.datasets.drionsdb.expert1 import dataset
 
     nose.tools.eq_(len(dataset["train"]), 60)
+    nose.tools.eq_(dataset["train"].augmented, True)
     for sample in dataset["train"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -304,6 +355,7 @@ def test_drionsdb_expert1():
         nose.tools.eq_(sample[2].dtype, torch.float32)
 
     nose.tools.eq_(len(dataset["test"]), 50)
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["test"][:N]:
         nose.tools.eq_(len(sample), 3)
         assert isinstance(sample[0], str)
@@ -322,6 +374,8 @@ def test_drive_covd():
     from ..configs.datasets.drive.covd import dataset
 
     nose.tools.eq_(len(dataset["train"]), 53)
+    #nose.tools.eq_(dataset["train"].augmented, True)  # not checked: "train" is a ConcatDataset
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["train"]:
         assert 3 <= len(sample) <= 4
         assert isinstance(sample[0], str)
@@ -344,6 +398,8 @@ def test_drive_ssl():
     from ..configs.datasets.drive.ssl import dataset
 
     nose.tools.eq_(len(dataset["train"]), 53)
+    #nose.tools.eq_(dataset["train"].augmented, True)  # not checked: "train" is a ConcatDataset
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["train"]:
         assert 5 <= len(sample) <= 6
         assert isinstance(sample[0], str)
@@ -372,6 +428,8 @@ def test_stare_covd():
     from ..configs.datasets.stare.covd import dataset
 
     nose.tools.eq_(len(dataset["train"]), 63)
+    #nose.tools.eq_(dataset["train"].augmented, True)  # not checked: "train" is a ConcatDataset
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["train"]:
         assert 3 <= len(sample) <= 4
         assert isinstance(sample[0], str)
@@ -393,6 +451,8 @@ def test_chasedb1_covd():
     from ..configs.datasets.chasedb1.covd import dataset
 
     nose.tools.eq_(len(dataset["train"]), 65)
+    #nose.tools.eq_(dataset["train"].augmented, True)  # not checked: "train" is a ConcatDataset
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["train"]:
         assert 3 <= len(sample) <= 4
         assert isinstance(sample[0], str)
@@ -414,6 +474,8 @@ def test_hrf_covd():
     from ..configs.datasets.hrf.covd import dataset
 
     nose.tools.eq_(len(dataset["train"]), 58)
+    #nose.tools.eq_(dataset["train"].augmented, True)  # not checked: "train" is a ConcatDataset
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["train"]:
         assert 3 <= len(sample) <= 4
         assert isinstance(sample[0], str)
@@ -435,6 +497,8 @@ def test_iostar_covd():
     from ..configs.datasets.iostar.covd import dataset
 
     nose.tools.eq_(len(dataset["train"]), 53)
+    #nose.tools.eq_(dataset["train"].augmented, True)  # not checked: "train" is a ConcatDataset
+    nose.tools.eq_(dataset["test"].augmented, False)
     for sample in dataset["train"]:
         assert 3 <= len(sample) <= 4
         assert isinstance(sample[0], str)
-- 
GitLab
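
For reference, the behaviour exercised by test_stare_augmentation_manipulation
can be used roughly as follows. This is a minimal sketch: only the "augmented"
attribute and the not_augmented() context manager come from the tests above; it
is shown with the DRIVE configuration on the assumption that it exposes the
same interface as the STARE one used in the test, and the rest is illustrative:

    from bob.ip.binseg.configs.datasets.drive.default import dataset

    # training subsets are created with augmentation transforms enabled
    assert dataset["train"].augmented

    # temporarily disable augmentation, e.g. to evaluate on the training
    # data with the same deterministic transforms as the test subset
    with dataset["train"].not_augmented() as d:
        assert not d.augmented
        sample = d[0]  # a sample without random augmentations applied

    # the augmentation transforms are restored once the context exits
    assert dataset["train"].augmented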