diff --git a/tests/test_11k_v2.py b/tests/test_11k_v2.py index 38d129a7ddb5f302998e42e0487e92738f419bcc..dc566e6d13541eaa62fcd6ac75f5ad503bb25099 100644 --- a/tests/test_11k_v2.py +++ b/tests/test_11k_v2.py @@ -39,8 +39,8 @@ def test_protocol_consistency(): for s in subset["test"]: assert s.label in [0.0, 1.0] - # Cross-validation fold 0-9 - for f in range(10): + # Cross-validation fold 0-8 + for f in range(9): subset = dataset.subsets("fold_" + str(f)) assert len(subset) == 3 @@ -69,6 +69,34 @@ def test_protocol_consistency(): for s in subset["test"]: assert s.label in [0.0, 1.0] + # Cross-validation fold 9 + subset = dataset.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] def test_protocol_consistency_bbox(): from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes @@ -106,8 +134,8 @@ def test_protocol_consistency_bbox(): for s in subset["train"]: assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") - # Cross-validation fold 0-9 - for f in range(10): + # Cross-validation fold 0-8 + for f in range(9): subset = dataset_with_bboxes.subsets("fold_" + str(f)) assert len(subset) == 3 @@ -140,6 +168,39 @@ def test_protocol_consistency_bbox(): for s in subset["train"]: assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + # Cross-validation fold 9 + subset = dataset_with_bboxes.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") def test_loading(): diff --git a/tests/test_11k_v2_RS.py b/tests/test_11k_v2_RS.py index 9eb1bd1f0171a1f49688133c542ab59f4817ed12..fe6dd8ee6096d53a7bdc327e6fe5722f29165c62 100644 --- a/tests/test_11k_v2_RS.py +++ b/tests/test_11k_v2_RS.py @@ -35,8 +35,8 @@ # for s in subset["test"]: # assert s.label in [0.0, 1.0] -# # Cross-validation fold 0-9 -# for f in range(10): +# # Cross-validation fold 0-8 +# for f in range(9): # subset = dataset.subsets("fold_" + str(f)) # assert len(subset) == 3 @@ -65,6 +65,35 @@ # for s in subset["test"]: # assert s.label in [0.0, 1.0] +# # Cross-validation fold 9 +# subset = dataset.subsets("fold_9") +# assert len(subset) == 3 + +# assert "train" in subset +# assert len(subset["train"]) == 6003 +# for s in subset["train"]: +# assert s.key.startswith("images/") + +# assert "validation" in subset +# assert len(subset["validation"]) == 1530 +# for s in subset["validation"]: +# assert s.key.startswith("images/") + +# assert "test" in subset +# assert len(subset["test"]) == 836 +# for s in subset["test"]: +# assert s.key.startswith("images/") + +# # Check labels +# for s in subset["train"]: +# assert s.label in [0.0, 1.0] + +# for s in subset["validation"]: +# assert s.label in [0.0, 1.0] + +# for s in subset["test"]: +# assert s.label in [0.0, 1.0] + # @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") # def test_loading(): diff --git a/tests/test_11k_v3.py b/tests/test_11k_v3.py index b15c38cab75d24d21d4af8763f96675a1069ae85..f2f66c9016fd2f5344d9e2bcceb63925f3284ba0 100644 --- a/tests/test_11k_v3.py +++ b/tests/test_11k_v3.py @@ -39,8 +39,8 @@ def test_protocol_consistency(): for s in subset["test"]: assert s.label in [0.0, 1.0, 2.0, 3.0] - # Cross-validation fold 0-9 - for f in range(10): + # Cross-validation fold 0-8 + for f in range(9): subset = dataset.subsets("fold_" + str(f)) assert len(subset) == 3 @@ -69,6 +69,34 @@ def test_protocol_consistency(): for s in subset["test"]: assert s.label in [0.0, 1.0, 2.0, 3.0] + # Cross-validation fold 9 + subset = dataset.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] def test_protocol_consistency_bbox(): from ptbench.data.tbx11k_simplified_v3 import dataset_with_bboxes @@ -106,8 +134,8 @@ def test_protocol_consistency_bbox(): for s in subset["train"]: assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") - # Cross-validation fold 0-9 - for f in range(10): + # Cross-validation fold 0-8 + for f in range(9): subset = dataset_with_bboxes.subsets("fold_" + str(f)) assert len(subset) == 3 @@ -140,6 +168,39 @@ def test_protocol_consistency_bbox(): for s in subset["train"]: assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + # Cross-validation fold 9 + subset = dataset_with_bboxes.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0, 2.0, 3.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v3") def test_loading(): diff --git a/tests/test_11k_v3_RS.py b/tests/test_11k_v3_RS.py index 302fbce962438546d41c3823e7f9c07f70711a88..4c3c28a0ee47114f4d64cdfb98ed36b90e4aff42 100644 --- a/tests/test_11k_v3_RS.py +++ b/tests/test_11k_v3_RS.py @@ -35,8 +35,8 @@ # for s in subset["test"]: # assert s.label in [0.0, 1.0, 2.0, 3.0] -# # Cross-validation fold 0-9 -# for f in range(10): +# # Cross-validation fold 0-8 +# for f in range(9): # subset = dataset.subsets("fold_" + str(f)) # assert len(subset) == 3 @@ -65,6 +65,35 @@ # for s in subset["test"]: # assert s.label in [0.0, 1.0, 2.0, 3.0] +# # Cross-validation fold 9 +# subset = dataset.subsets("fold_9") +# assert len(subset) == 3 + +# assert "train" in subset +# assert len(subset["train"]) == 6003 +# for s in subset["train"]: +# assert s.key.startswith("images/") + +# assert "validation" in subset +# assert len(subset["validation"]) == 1530 +# for s in subset["validation"]: +# assert s.key.startswith("images/") + +# assert "test" in subset +# assert len(subset["test"]) == 836 +# for s in subset["test"]: +# assert s.key.startswith("images/") + +# # Check labels +# for s in subset["train"]: +# assert s.label in [0.0, 1.0, 2.0, 3.0] + +# for s in subset["validation"]: +# assert s.label in [0.0, 1.0, 2.0, 3.0] + +# for s in subset["test"]: +# assert s.label in [0.0, 1.0, 2.0, 3.0] + # @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") # def test_loading():