From 640950dede0dd9ac6d4988da222166dd82ac4fc8 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 27 Jul 2023 20:54:19 +0200
Subject: [PATCH] [data.shenzhen] Use right split name; separate split creation
 so it is reusable

---
 src/ptbench/data/shenzhen/datamodule.py | 15 ++++++++++-----
 src/ptbench/data/shenzhen/default.py    |  2 +-
 src/ptbench/data/shenzhen/fold_0.py     |  2 +-
 src/ptbench/data/shenzhen/fold_1.py     |  2 +-
 src/ptbench/data/shenzhen/fold_2.py     |  2 +-
 src/ptbench/data/shenzhen/fold_3.py     |  2 +-
 src/ptbench/data/shenzhen/fold_4.py     |  2 +-
 src/ptbench/data/shenzhen/fold_5.py     |  2 +-
 src/ptbench/data/shenzhen/fold_6.py     |  2 +-
 src/ptbench/data/shenzhen/fold_7.py     |  2 +-
 src/ptbench/data/shenzhen/fold_8.py     |  2 +-
 src/ptbench/data/shenzhen/fold_9.py     |  2 +-
 12 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/src/ptbench/data/shenzhen/datamodule.py b/src/ptbench/data/shenzhen/datamodule.py
index 45ce8762..8307396d 100644
--- a/src/ptbench/data/shenzhen/datamodule.py
+++ b/src/ptbench/data/shenzhen/datamodule.py
@@ -13,6 +13,7 @@ from ...utils.rc import load_rc
 from ..datamodule import CachingDataModule
 from ..image_utils import remove_black_borders
 from ..split import JSONDatabaseSplit
+from ..typing import DatabaseSplit
 from ..typing import RawDataLoader as _BaseRawDataLoader
 from ..typing import Sample
 
@@ -93,6 +94,14 @@ class RawDataLoader(_BaseRawDataLoader):
         return sample[1]
 
 
+def make_split(basename: str) -> DatabaseSplit:
+    """Return the JSON split ``basename`` from this module's package."""
+
+    return JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
+    )
+
+
 class DataModule(CachingDataModule):
     """Shenzhen datamodule for computer-aided diagnosis.
 
@@ -128,10 +137,6 @@ class DataModule(CachingDataModule):
 
     def __init__(self, split_filename: str):
         super().__init__(
-            database_split=JSONDatabaseSplit(
-                importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
-                    split_filename
-                )
-            ),
+            database_split=make_split(split_filename),
             raw_data_loader=RawDataLoader(),
         )
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index 0e29c385..93517c03 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -4,7 +4,7 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("default.json.bz2")
+datamodule = DataModule("default.json")
 """Default Shenzen TB database split.
 
 * Training samples: 64% of TB and healthy CXR (including labels)
diff --git a/src/ptbench/data/shenzhen/fold_0.py b/src/ptbench/data/shenzhen/fold_0.py
index c810e85c..3d114d07 100644
--- a/src/ptbench/data/shenzhen/fold_0.py
+++ b/src/ptbench/data/shenzhen/fold_0.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_0.json.bz2")
+datamodule = DataModule("fold_0.json")
diff --git a/src/ptbench/data/shenzhen/fold_1.py b/src/ptbench/data/shenzhen/fold_1.py
index 736a778d..cd3a8cb6 100644
--- a/src/ptbench/data/shenzhen/fold_1.py
+++ b/src/ptbench/data/shenzhen/fold_1.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_1.json.bz2")
+datamodule = DataModule("fold_1.json")
diff --git a/src/ptbench/data/shenzhen/fold_2.py b/src/ptbench/data/shenzhen/fold_2.py
index 48df1bfe..44eeda80 100644
--- a/src/ptbench/data/shenzhen/fold_2.py
+++ b/src/ptbench/data/shenzhen/fold_2.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_2.json.bz2")
+datamodule = DataModule("fold_2.json")
diff --git a/src/ptbench/data/shenzhen/fold_3.py b/src/ptbench/data/shenzhen/fold_3.py
index 9967e4ea..f24fb314 100644
--- a/src/ptbench/data/shenzhen/fold_3.py
+++ b/src/ptbench/data/shenzhen/fold_3.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_3.json.bz2")
+datamodule = DataModule("fold_3.json")
diff --git a/src/ptbench/data/shenzhen/fold_4.py b/src/ptbench/data/shenzhen/fold_4.py
index 8630ee09..58456d38 100644
--- a/src/ptbench/data/shenzhen/fold_4.py
+++ b/src/ptbench/data/shenzhen/fold_4.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_4.json.bz2")
+datamodule = DataModule("fold_4.json")
diff --git a/src/ptbench/data/shenzhen/fold_5.py b/src/ptbench/data/shenzhen/fold_5.py
index 0c7504c5..92796746 100644
--- a/src/ptbench/data/shenzhen/fold_5.py
+++ b/src/ptbench/data/shenzhen/fold_5.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_5.json.bz2")
+datamodule = DataModule("fold_5.json")
diff --git a/src/ptbench/data/shenzhen/fold_6.py b/src/ptbench/data/shenzhen/fold_6.py
index 2f8e8e32..9566b7cf 100644
--- a/src/ptbench/data/shenzhen/fold_6.py
+++ b/src/ptbench/data/shenzhen/fold_6.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_6.json.bz2")
+datamodule = DataModule("fold_6.json")
diff --git a/src/ptbench/data/shenzhen/fold_7.py b/src/ptbench/data/shenzhen/fold_7.py
index eb5d6f00..8c7ed885 100644
--- a/src/ptbench/data/shenzhen/fold_7.py
+++ b/src/ptbench/data/shenzhen/fold_7.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_7.json.bz2")
+datamodule = DataModule("fold_7.json")
diff --git a/src/ptbench/data/shenzhen/fold_8.py b/src/ptbench/data/shenzhen/fold_8.py
index a9480359..fb5332ce 100644
--- a/src/ptbench/data/shenzhen/fold_8.py
+++ b/src/ptbench/data/shenzhen/fold_8.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_8.json.bz2")
+datamodule = DataModule("fold_8.json")
diff --git a/src/ptbench/data/shenzhen/fold_9.py b/src/ptbench/data/shenzhen/fold_9.py
index daa85e03..d1626586 100644
--- a/src/ptbench/data/shenzhen/fold_9.py
+++ b/src/ptbench/data/shenzhen/fold_9.py
@@ -4,4 +4,4 @@
 
 from .datamodule import DataModule
 
-datamodule = DataModule("fold_9.json.bz2")
+datamodule = DataModule("fold_9.json")
-- 
GitLab