From 795e4494b6d1c3d03dac9f887296655430b4571b Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 21 Jul 2023 20:35:33 +0200
Subject: [PATCH] [data.split] Make variables private

---
 src/ptbench/data/datamodule.py |  6 ++----
 src/ptbench/data/split.py      | 30 +++++++++++++++---------------
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index af0d513a..440c1ea4 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -686,9 +686,7 @@ class CachingDataModule(lightning.LightningDataModule):
     def _val_dataset_keys(self) -> list[str]:
         """Returns list of validation dataset names."""
         return ["validation"] + [
-            k
-            for k in self.database_split.subsets.keys()
-            if k.startswith("monitor-")
+            k for k in self.database_split.keys() if k.startswith("monitor-")
         ]
 
     def setup(self, stage: str) -> None:
@@ -729,7 +727,7 @@ class CachingDataModule(lightning.LightningDataModule):
             self._setup_dataset("test")
 
         elif stage == "predict":
-            for k in self.database_split.subsets.keys():
+            for k in self.database_split.keys():
                 self._setup_dataset(k)
 
     def teardown(self, stage: str) -> None:
diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py
index 606e40cb..4b055be2 100644
--- a/src/ptbench/data/split.py
+++ b/src/ptbench/data/split.py
@@ -68,8 +68,8 @@ class JSONDatabaseSplit(DatabaseSplit):
     def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
         if isinstance(path, str):
             path = pathlib.Path(path)
-        self.path = path
-        self.subsets = self._load_split_from_disk()
+        self._path = path
+        self._subsets = self._load_split_from_disk()
 
     def _load_split_from_disk(self) -> DatabaseSplit:
         """Loads all subsets in a split from its file system representation.
@@ -86,25 +86,25 @@ class JSONDatabaseSplit(DatabaseSplit):
             A dictionary mapping subset names to lists of JSON objects
         """
 
-        if str(self.path).endswith(".bz2"):
-            logger.debug(f"Loading database split from {str(self.path)}...")
-            with __import__("bz2").open(self.path) as f:
+        if str(self._path).endswith(".bz2"):
+            logger.debug(f"Loading database split from {str(self._path)}...")
+            with __import__("bz2").open(self._path) as f:
                 return json.load(f)
         else:
-            with self.path.open() as f:
+            with self._path.open() as f:
                 return json.load(f)
 
     def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
         """Accesses subset ``key`` from this split."""
-        return self.subsets[key]
+        return self._subsets[key]
 
     def __iter__(self):
         """Iterates over the subsets."""
-        return iter(self.subsets)
+        return iter(self._subsets)
 
     def __len__(self) -> int:
         """How many subsets we currently have."""
-        return len(self.subsets)
+        return len(self._subsets)
 
 
 class CSVDatabaseSplit(DatabaseSplit):
@@ -149,8 +149,8 @@ class CSVDatabaseSplit(DatabaseSplit):
         assert (
             directory.is_dir()
         ), f"`{str(directory)}` is not a valid directory"
-        self.directory = directory
-        self.subsets = self._load_split_from_disk()
+        self._directory = directory
+        self._subsets = self._load_split_from_disk()
 
     def _load_split_from_disk(self) -> DatabaseSplit:
         """Loads all subsets in a split from its file system representation.
@@ -168,7 +168,7 @@ class CSVDatabaseSplit(DatabaseSplit):
         """
 
         retval: DatabaseSplit = {}
-        for subset in self.directory.iterdir():
+        for subset in self._directory.iterdir():
             if str(subset).endswith(".csv.bz2"):
                 logger.debug(f"Loading database split from {subset}...")
                 with __import__("bz2").open(subset) as f:
@@ -188,15 +188,15 @@ class CSVDatabaseSplit(DatabaseSplit):
 
     def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
         """Accesses subset ``key`` from this split."""
-        return self.subsets[key]
+        return self._subsets[key]
 
     def __iter__(self):
         """Iterates over the subsets."""
-        return iter(self.subsets)
+        return iter(self._subsets)
 
     def __len__(self) -> int:
         """How many subsets we currently have."""
-        return len(self.subsets)
+        return len(self._subsets)
 
 
 def check_database_split_loading(
-- 
GitLab