From 795e4494b6d1c3d03dac9f887296655430b4571b Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 21 Jul 2023 20:35:33 +0200
Subject: [PATCH] [data.split] Make variables private

---
 src/ptbench/data/datamodule.py |  6 ++----
 src/ptbench/data/split.py      | 30 +++++++++++++++---------------
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index af0d513a..440c1ea4 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -686,9 +686,7 @@ class CachingDataModule(lightning.LightningDataModule):
     def _val_dataset_keys(self) -> list[str]:
         """Returns list of validation dataset names."""
         return ["validation"] + [
-            k
-            for k in self.database_split.subsets.keys()
-            if k.startswith("monitor-")
+            k for k in self.database_split.keys() if k.startswith("monitor-")
         ]
 
     def setup(self, stage: str) -> None:
@@ -729,7 +727,7 @@ class CachingDataModule(lightning.LightningDataModule):
             self._setup_dataset("test")
 
         elif stage == "predict":
-            for k in self.database_split.subsets.keys():
+            for k in self.database_split.keys():
                 self._setup_dataset(k)
 
     def teardown(self, stage: str) -> None:
diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py
index 606e40cb..4b055be2 100644
--- a/src/ptbench/data/split.py
+++ b/src/ptbench/data/split.py
@@ -68,8 +68,8 @@ class JSONDatabaseSplit(DatabaseSplit):
     def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
         if isinstance(path, str):
             path = pathlib.Path(path)
-        self.path = path
-        self.subsets = self._load_split_from_disk()
+        self._path = path
+        self._subsets = self._load_split_from_disk()
 
     def _load_split_from_disk(self) -> DatabaseSplit:
         """Loads all subsets in a split from its file system representation.
@@ -86,25 +86,25 @@
 
             A dictionary mapping subset names to lists of JSON objects
         """
-        if str(self.path).endswith(".bz2"):
-            logger.debug(f"Loading database split from {str(self.path)}...")
-            with __import__("bz2").open(self.path) as f:
+        if str(self._path).endswith(".bz2"):
+            logger.debug(f"Loading database split from {str(self._path)}...")
+            with __import__("bz2").open(self._path) as f:
                 return json.load(f)
         else:
-            with self.path.open() as f:
+            with self._path.open() as f:
                 return json.load(f)
 
     def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
         """Accesses subset ``key`` from this split."""
-        return self.subsets[key]
+        return self._subsets[key]
 
     def __iter__(self):
         """Iterates over the subsets."""
-        return iter(self.subsets)
+        return iter(self._subsets)
 
     def __len__(self) -> int:
         """How many subsets we currently have."""
-        return len(self.subsets)
+        return len(self._subsets)
 
 
 class CSVDatabaseSplit(DatabaseSplit):
@@ -149,8 +149,8 @@ class CSVDatabaseSplit(DatabaseSplit):
         assert (
             directory.is_dir()
         ), f"`{str(directory)}` is not a valid directory"
-        self.directory = directory
-        self.subsets = self._load_split_from_disk()
+        self._directory = directory
+        self._subsets = self._load_split_from_disk()
 
     def _load_split_from_disk(self) -> DatabaseSplit:
         """Loads all subsets in a split from its file system representation.
@@ -168,7 +168,7 @@ class CSVDatabaseSplit(DatabaseSplit):
         """
 
         retval: DatabaseSplit = {}
-        for subset in self.directory.iterdir():
+        for subset in self._directory.iterdir():
             if str(subset).endswith(".csv.bz2"):
                 logger.debug(f"Loading database split from {subset}...")
                 with __import__("bz2").open(subset) as f:
@@ -188,15 +188,15 @@ class CSVDatabaseSplit(DatabaseSplit):
 
     def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
         """Accesses subset ``key`` from this split."""
-        return self.subsets[key]
+        return self._subsets[key]
 
     def __iter__(self):
         """Iterates over the subsets."""
-        return iter(self.subsets)
+        return iter(self._subsets)
 
     def __len__(self) -> int:
         """How many subsets we currently have."""
-        return len(self.subsets)
+        return len(self._subsets)
 
 
 def check_database_split_loading(
--
GitLab
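
Usage note (not part of the patch): with `subsets` now private, split objects
are meant to be consumed through their mapping-style interface (`__getitem__`,
`__iter__`, `__len__`), which is what the datamodule.py hunks above switch to.
A minimal sketch, assuming a hypothetical JSON split file "my-split.json" that
contains a "train" subset:

    from ptbench.data.split import JSONDatabaseSplit

    split = JSONDatabaseSplit("my-split.json")  # hypothetical path

    # Iterate over subset names and access each subset's sample list,
    # instead of reaching into the former `split.subsets` attribute:
    for name in split:
        print(name, len(split[name]))

    assert "train" in list(split)  # membership via the iteration protocol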