From 1f1a5b2bbed2dafc70506408b78ba357d6129120 Mon Sep 17 00:00:00 2001 From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch> Date: Sat, 15 Apr 2023 22:42:49 +0200 Subject: [PATCH] fixed wrong val/test set length number in tests --- tests/test_11k_v2.py | 69 ++++++++++++++++++++++++++++++++++++++--- tests/test_11k_v2_RS.py | 33 ++++++++++++++++++-- tests/test_11k_v3.py | 69 ++++++++++++++++++++++++++++++++++++++--- tests/test_11k_v3_RS.py | 33 ++++++++++++++++++-- 4 files changed, 192 insertions(+), 12 deletions(-) diff --git a/tests/test_11k_v2.py b/tests/test_11k_v2.py index 38d129a7..dc566e6d 100644 --- a/tests/test_11k_v2.py +++ b/tests/test_11k_v2.py @@ -39,8 +39,8 @@ def test_protocol_consistency(): for s in subset["test"]: assert s.label in [0.0, 1.0] - # Cross-validation fold 0-9 - for f in range(10): + # Cross-validation fold 0-8 + for f in range(9): subset = dataset.subsets("fold_" + str(f)) assert len(subset) == 3 @@ -69,6 +69,34 @@ def test_protocol_consistency(): for s in subset["test"]: assert s.label in [0.0, 1.0] + # Cross-validation fold 9 + subset = dataset.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] def test_protocol_consistency_bbox(): from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes @@ -106,8 +134,8 @@ def test_protocol_consistency_bbox(): for s in subset["train"]: assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") - # Cross-validation fold 0-9 - for f in range(10): + # Cross-validation fold 0-8 + for f in range(9): subset = dataset_with_bboxes.subsets("fold_" + str(f)) assert len(subset) == 3 @@ -140,6 +168,39 @@ def test_protocol_consistency_bbox(): for s in subset["train"]: assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + # Cross-validation fold 9 + subset = dataset_with_bboxes.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") def test_loading(): diff --git a/tests/test_11k_v2_RS.py b/tests/test_11k_v2_RS.py index 9eb1bd1f..fe6dd8ee 100644 --- a/tests/test_11k_v2_RS.py +++ b/tests/test_11k_v2_RS.py @@ -35,8 +35,8 @@ # for s in subset["test"]: # assert s.label in [0.0, 1.0] -# # Cross-validation fold 0-9 -# for f in range(10): +# # Cross-validation fold 0-8 +# for f in range(9): # subset = dataset.subsets("fold_" + str(f)) # assert len(subset) == 3 @@ -65,6 +65,35 @@ # for s in subset["test"]: # assert s.label in [0.0, 1.0] +# # Cross-validation fold 9 +# subset = dataset.subsets("fold_9") +# assert len(subset) == 3 + +# assert "train" in subset +# assert len(subset["train"]) == 6003 +# for s in subset["train"]: +# assert s.key.startswith("images/") + +# assert "validation" in subset +# assert len(subset["validation"]) == 1530 +# for s in subset["validation"]: +# assert s.key.startswith("images/") + +# assert "test" in subset +# assert len(subset["test"]) == 836 +# for s in subset["test"]: +# assert s.key.startswith("images/") + +# # Check labels +# for s in subset["train"]: +# assert s.label in [0.0, 1.0] + +# for s in subset["validation"]: +# assert s.label in [0.0, 1.0] + +# for s in subset["test"]: +# assert s.label in [0.0, 1.0] + # @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") # def test_loading(): diff --git a/tests/test_11k_v3.py b/tests/test_11k_v3.py index b15c38ca..f2f66c90 100644 --- a/tests/test_11k_v3.py +++ b/tests/test_11k_v3.py @@ -39,8 +39,8 @@ def test_protocol_consistency(): for s in subset["test"]: assert s.label in [0.0, 1.0, 2.0, 3.0] - # Cross-validation fold 0-9 - for f in range(10): + # Cross-validation fold 0-8 + for f in range(9): subset = dataset.subsets("fold_" + str(f)) assert len(subset) == 3 @@ -69,6 +69,34 @@ def test_protocol_consistency(): for s in subset["test"]: assert s.label in [0.0, 1.0, 2.0, 3.0] + # Cross-validation fold 9 + subset = dataset.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] def test_protocol_consistency_bbox(): from ptbench.data.tbx11k_simplified_v3 import dataset_with_bboxes @@ -106,8 +134,8 @@ def test_protocol_consistency_bbox(): for s in subset["train"]: assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") - # Cross-validation fold 0-9 - for f in range(10): + # Cross-validation fold 0-8 + for f in range(9): subset = dataset_with_bboxes.subsets("fold_" + str(f)) assert len(subset) == 3 @@ -140,6 +168,39 @@ def test_protocol_consistency_bbox(): for s in subset["train"]: assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + # Cross-validation fold 9 + subset = dataset_with_bboxes.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v3") def test_loading(): diff --git a/tests/test_11k_v3_RS.py b/tests/test_11k_v3_RS.py index 302fbce9..4c3c28a0 100644 --- a/tests/test_11k_v3_RS.py +++ b/tests/test_11k_v3_RS.py @@ -35,8 +35,8 @@ # for s in subset["test"]: # assert s.label in [0.0, 1.0, 2.0, 3.0] -# # Cross-validation fold 0-9 -# for f in range(10): +# # Cross-validation fold 0-8 +# for f in range(9): # subset = dataset.subsets("fold_" + str(f)) # assert len(subset) == 3 @@ -65,6 +65,35 @@ # for s in subset["test"]: # assert s.label in [0.0, 1.0, 2.0, 3.0] +# # Cross-validation fold 9 +# subset = dataset.subsets("fold_9") +# assert len(subset) == 3 + +# assert "train" in subset +# assert len(subset["train"]) == 6003 +# for s in subset["train"]: +# assert s.key.startswith("images/") + +# assert "validation" in subset +# assert len(subset["validation"]) == 1530 +# for s in subset["validation"]: +# assert s.key.startswith("images/") + +# assert "test" in subset +# assert len(subset["test"]) == 836 +# for s in subset["test"]: +# assert s.key.startswith("images/") + +# # Check labels +# for s in subset["train"]: +# assert s.label in [0.0, 1.0, 2.0, 3.0] + +# for s in subset["validation"]: +# assert s.label in [0.0, 1.0, 2.0, 3.0] + +# for s in subset["test"]: +# assert s.label in [0.0, 1.0, 2.0, 3.0] + # @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") # def test_loading(): -- GitLab