From 1f1a5b2bbed2dafc70506408b78ba357d6129120 Mon Sep 17 00:00:00 2001
From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch>
Date: Sat, 15 Apr 2023 22:42:49 +0200
Subject: [PATCH] fixed wrong val/test set length number in tests

---
 tests/test_11k_v2.py    | 69 ++++++++++++++++++++++++++++++++++++++---
 tests/test_11k_v2_RS.py | 33 ++++++++++++++++++--
 tests/test_11k_v3.py    | 69 ++++++++++++++++++++++++++++++++++++++---
 tests/test_11k_v3_RS.py | 33 ++++++++++++++++++--
 4 files changed, 192 insertions(+), 12 deletions(-)

diff --git a/tests/test_11k_v2.py b/tests/test_11k_v2.py
index 38d129a7..dc566e6d 100644
--- a/tests/test_11k_v2.py
+++ b/tests/test_11k_v2.py
@@ -39,8 +39,8 @@ def test_protocol_consistency():
     for s in subset["test"]:
         assert s.label in [0.0, 1.0]
 
-    # Cross-validation fold 0-9
-    for f in range(10):
+    # Cross-validation fold 0-8
+    for f in range(9):
         subset = dataset.subsets("fold_" + str(f))
         assert len(subset) == 3
 
@@ -69,6 +69,34 @@ def test_protocol_consistency():
         for s in subset["test"]:
             assert s.label in [0.0, 1.0]
 
+    # Cross-validation fold 9
+    subset = dataset.subsets("fold_9")
+    assert len(subset) == 3
+
+    assert "train" in subset
+    assert len(subset["train"]) == 6003
+    for s in subset["train"]:
+        assert s.key.startswith("images/")
+
+    assert "validation" in subset
+    assert len(subset["validation"]) == 1530
+    for s in subset["validation"]:
+        assert s.key.startswith("images/")
+
+    assert "test" in subset
+    assert len(subset["test"]) == 836
+    for s in subset["test"]:
+        assert s.key.startswith("images/")
+
+    # Check labels
+    for s in subset["train"]:
+        assert s.label in [0.0, 1.0]
+
+    for s in subset["validation"]:
+        assert s.label in [0.0, 1.0]
+
+    for s in subset["test"]:
+        assert s.label in [0.0, 1.0]
 
 def test_protocol_consistency_bbox():
     from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes
@@ -106,8 +134,8 @@ def test_protocol_consistency_bbox():
     for s in subset["train"]:
         assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':")
 
-    # Cross-validation fold 0-9
-    for f in range(10):
+    # Cross-validation fold 0-8
+    for f in range(9):
         subset = dataset_with_bboxes.subsets("fold_" + str(f))
         assert len(subset) == 3
 
@@ -140,6 +168,39 @@ def test_protocol_consistency_bbox():
         for s in subset["train"]:
             assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':")
 
+    # Cross-validation fold 9
+    subset = dataset_with_bboxes.subsets("fold_9")
+    assert len(subset) == 3
+
+    assert "train" in subset
+    assert len(subset["train"]) == 6003
+    for s in subset["train"]:
+        assert s.key.startswith("images/")
+
+    assert "validation" in subset
+    assert len(subset["validation"]) == 1530
+    for s in subset["validation"]:
+        assert s.key.startswith("images/")
+
+    assert "test" in subset
+    assert len(subset["test"]) == 836
+    for s in subset["test"]:
+        assert s.key.startswith("images/")
+
+    # Check labels
+    for s in subset["train"]:
+        assert s.label in [0.0, 1.0]
+
+    for s in subset["validation"]:
+        assert s.label in [0.0, 1.0]
+
+    for s in subset["test"]:
+        assert s.label in [0.0, 1.0]
+
+    # Check bounding boxes
+    for s in subset["train"]:
+        assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':")
+
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2")
 def test_loading():
diff --git a/tests/test_11k_v2_RS.py b/tests/test_11k_v2_RS.py
index 9eb1bd1f..fe6dd8ee 100644
--- a/tests/test_11k_v2_RS.py
+++ b/tests/test_11k_v2_RS.py
@@ -35,8 +35,8 @@
 #     for s in subset["test"]:
 #         assert s.label in [0.0, 1.0]
 
-#     # Cross-validation fold 0-9
-#     for f in range(10):
+#     # Cross-validation fold 0-8
+#     for f in range(9):
 #         subset = dataset.subsets("fold_" + str(f))
 #         assert len(subset) == 3
 
@@ -65,6 +65,35 @@
 #         for s in subset["test"]:
 #             assert s.label in [0.0, 1.0]
 
+#     # Cross-validation fold 9
+#     subset = dataset.subsets("fold_9")
+#     assert len(subset) == 3
+
+#     assert "train" in subset
+#     assert len(subset["train"]) == 6003
+#     for s in subset["train"]:
+#         assert s.key.startswith("images/")
+
+#     assert "validation" in subset
+#     assert len(subset["validation"]) == 1530
+#     for s in subset["validation"]:
+#         assert s.key.startswith("images/")
+
+#     assert "test" in subset
+#     assert len(subset["test"]) == 836
+#     for s in subset["test"]:
+#         assert s.key.startswith("images/")
+
+#     # Check labels
+#     for s in subset["train"]:
+#         assert s.label in [0.0, 1.0]
+
+#     for s in subset["validation"]:
+#         assert s.label in [0.0, 1.0]
+
+#     for s in subset["test"]:
+#         assert s.label in [0.0, 1.0]
+
 
 # @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified")
 # def test_loading():
diff --git a/tests/test_11k_v3.py b/tests/test_11k_v3.py
index b15c38ca..f2f66c90 100644
--- a/tests/test_11k_v3.py
+++ b/tests/test_11k_v3.py
@@ -39,8 +39,8 @@ def test_protocol_consistency():
     for s in subset["test"]:
         assert s.label in [0.0, 1.0, 2.0, 3.0]
 
-    # Cross-validation fold 0-9
-    for f in range(10):
+    # Cross-validation fold 0-8
+    for f in range(9):
         subset = dataset.subsets("fold_" + str(f))
         assert len(subset) == 3
 
@@ -69,6 +69,34 @@ def test_protocol_consistency():
         for s in subset["test"]:
             assert s.label in [0.0, 1.0, 2.0, 3.0]
 
+    # Cross-validation fold 9
+    subset = dataset.subsets("fold_9")
+    assert len(subset) == 3
+
+    assert "train" in subset
+    assert len(subset["train"]) == 6003
+    for s in subset["train"]:
+        assert s.key.startswith("images/")
+
+    assert "validation" in subset
+    assert len(subset["validation"]) == 1530
+    for s in subset["validation"]:
+        assert s.key.startswith("images/")
+
+    assert "test" in subset
+    assert len(subset["test"]) == 836
+    for s in subset["test"]:
+        assert s.key.startswith("images/")
+
+    # Check labels
+    for s in subset["train"]:
+        assert s.label in [0.0, 1.0, 2.0, 3.0]
+
+    for s in subset["validation"]:
+        assert s.label in [0.0, 1.0, 2.0, 3.0]
+
+    for s in subset["test"]:
+        assert s.label in [0.0, 1.0, 2.0, 3.0]
 
 def test_protocol_consistency_bbox():
     from ptbench.data.tbx11k_simplified_v3 import dataset_with_bboxes
@@ -106,8 +134,8 @@ def test_protocol_consistency_bbox():
     for s in subset["train"]:
         assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':")
 
-    # Cross-validation fold 0-9
-    for f in range(10):
+    # Cross-validation fold 0-8
+    for f in range(9):
         subset = dataset_with_bboxes.subsets("fold_" + str(f))
         assert len(subset) == 3
 
@@ -140,6 +168,39 @@ def test_protocol_consistency_bbox():
         for s in subset["train"]:
             assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':")
 
+    # Cross-validation fold 9
+    subset = dataset_with_bboxes.subsets("fold_9")
+    assert len(subset) == 3
+
+    assert "train" in subset
+    assert len(subset["train"]) == 6003
+    for s in subset["train"]:
+        assert s.key.startswith("images/")
+
+    assert "validation" in subset
+    assert len(subset["validation"]) == 1530
+    for s in subset["validation"]:
+        assert s.key.startswith("images/")
+
+    assert "test" in subset
+    assert len(subset["test"]) == 836
+    for s in subset["test"]:
+        assert s.key.startswith("images/")
+
+    # Check labels
+    for s in subset["train"]:
+        assert s.label in [0.0, 1.0, 2.0, 3.0]
+
+    for s in subset["validation"]:
+        assert s.label in [0.0, 1.0, 2.0, 3.0]
+
+    for s in subset["test"]:
+        assert s.label in [0.0, 1.0, 2.0, 3.0]
+
+    # Check bounding boxes
+    for s in subset["train"]:
+        assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':")
+
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v3")
 def test_loading():
diff --git a/tests/test_11k_v3_RS.py b/tests/test_11k_v3_RS.py
index 302fbce9..4c3c28a0 100644
--- a/tests/test_11k_v3_RS.py
+++ b/tests/test_11k_v3_RS.py
@@ -35,8 +35,8 @@
 #     for s in subset["test"]:
 #         assert s.label in [0.0, 1.0, 2.0, 3.0]
 
-#     # Cross-validation fold 0-9
-#     for f in range(10):
+#     # Cross-validation fold 0-8
+#     for f in range(9):
 #         subset = dataset.subsets("fold_" + str(f))
 #         assert len(subset) == 3
 
@@ -65,6 +65,35 @@
 #         for s in subset["test"]:
 #             assert s.label in [0.0, 1.0, 2.0, 3.0]
 
+#     # Cross-validation fold 9
+#     subset = dataset.subsets("fold_9")
+#     assert len(subset) == 3
+
+#     assert "train" in subset
+#     assert len(subset["train"]) == 6003
+#     for s in subset["train"]:
+#         assert s.key.startswith("images/")
+
+#     assert "validation" in subset
+#     assert len(subset["validation"]) == 1530
+#     for s in subset["validation"]:
+#         assert s.key.startswith("images/")
+
+#     assert "test" in subset
+#     assert len(subset["test"]) == 836
+#     for s in subset["test"]:
+#         assert s.key.startswith("images/")
+
+#     # Check labels
+#     for s in subset["train"]:
+#         assert s.label in [0.0, 1.0, 2.0, 3.0]
+
+#     for s in subset["validation"]:
+#         assert s.label in [0.0, 1.0, 2.0, 3.0]
+
+#     for s in subset["test"]:
+#         assert s.label in [0.0, 1.0, 2.0, 3.0]
+
 
 # @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified")
 # def test_loading():
-- 
GitLab