diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index af0d513ac1bd166179f64fb759bf9e422da8bf11..440c1ea4505efe58d8bf7b17b6b48779bafa9c38 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -686,9 +686,7 @@ class CachingDataModule(lightning.LightningDataModule): def _val_dataset_keys(self) -> list[str]: """Returns list of validation dataset names.""" return ["validation"] + [ - k - for k in self.database_split.subsets.keys() - if k.startswith("monitor-") + k for k in self.database_split.keys() if k.startswith("monitor-") ] def setup(self, stage: str) -> None: @@ -729,7 +727,7 @@ class CachingDataModule(lightning.LightningDataModule): self._setup_dataset("test") elif stage == "predict": - for k in self.database_split.subsets.keys(): + for k in self.database_split.keys(): self._setup_dataset(k) def teardown(self, stage: str) -> None: diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py index 606e40cbf5ca8e9fb16b8c7ae3c894a440042b0a..4b055be23bf42ca1fec67b146bc9c1b758175ca3 100644 --- a/src/ptbench/data/split.py +++ b/src/ptbench/data/split.py @@ -68,8 +68,8 @@ class JSONDatabaseSplit(DatabaseSplit): def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable): if isinstance(path, str): path = pathlib.Path(path) - self.path = path - self.subsets = self._load_split_from_disk() + self._path = path + self._subsets = self._load_split_from_disk() def _load_split_from_disk(self) -> DatabaseSplit: """Loads all subsets in a split from its file system representation. @@ -86,25 +86,25 @@ class JSONDatabaseSplit(DatabaseSplit): A dictionary mapping subset names to lists of JSON objects """ - if str(self.path).endswith(".bz2"): - logger.debug(f"Loading database split from {str(self.path)}...") - with __import__("bz2").open(self.path) as f: + if str(self._path).endswith(".bz2"): + logger.debug(f"Loading database split from {str(self._path)}...") + with __import__("bz2").open(self._path) as f: return json.load(f) else: - with self.path.open() as f: + with self._path.open() as f: return json.load(f) def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: """Accesses subset ``key`` from this split.""" - return self.subsets[key] + return self._subsets[key] def __iter__(self): """Iterates over the subsets.""" - return iter(self.subsets) + return iter(self._subsets) def __len__(self) -> int: """How many subsets we currently have.""" - return len(self.subsets) + return len(self._subsets) class CSVDatabaseSplit(DatabaseSplit): @@ -149,8 +149,8 @@ class CSVDatabaseSplit(DatabaseSplit): assert ( directory.is_dir() ), f"`{str(directory)}` is not a valid directory" - self.directory = directory - self.subsets = self._load_split_from_disk() + self._directory = directory + self._subsets = self._load_split_from_disk() def _load_split_from_disk(self) -> DatabaseSplit: """Loads all subsets in a split from its file system representation. @@ -168,7 +168,7 @@ class CSVDatabaseSplit(DatabaseSplit): """ retval: DatabaseSplit = {} - for subset in self.directory.iterdir(): + for subset in self._directory.iterdir(): if str(subset).endswith(".csv.bz2"): logger.debug(f"Loading database split from {subset}...") with __import__("bz2").open(subset) as f: @@ -188,15 +188,15 @@ class CSVDatabaseSplit(DatabaseSplit): def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: """Accesses subset ``key`` from this split.""" - return self.subsets[key] + return self._subsets[key] def __iter__(self): """Iterates over the subsets.""" - return iter(self.subsets) + return iter(self._subsets) def __len__(self) -> int: """How many subsets we currently have.""" - return len(self.subsets) + return len(self._subsets) def check_database_split_loading(