Skip to content
Snippets Groups Projects
Commit 8b2a2902 authored by André Anjos
Browse files

[data.split] Make variables private

parent 03ae3f4a
No related branches found
No related tags found
No related merge requests found
......@@ -686,9 +686,7 @@ class CachingDataModule(lightning.LightningDataModule):
def _val_dataset_keys(self) -> list[str]:
"""Returns list of validation dataset names."""
return ["validation"] + [
k
for k in self.database_split.subsets.keys()
if k.startswith("monitor-")
k for k in self.database_split.keys() if k.startswith("monitor-")
]
def setup(self, stage: str) -> None:
......@@ -729,7 +727,7 @@ class CachingDataModule(lightning.LightningDataModule):
self._setup_dataset("test")
elif stage == "predict":
for k in self.database_split.subsets.keys():
for k in self.database_split.keys():
self._setup_dataset(k)
def teardown(self, stage: str) -> None:
......
......@@ -68,8 +68,8 @@ class JSONDatabaseSplit(DatabaseSplit):
def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
if isinstance(path, str):
path = pathlib.Path(path)
self.path = path
self.subsets = self._load_split_from_disk()
self._path = path
self._subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation.
......@@ -86,25 +86,25 @@ class JSONDatabaseSplit(DatabaseSplit):
A dictionary mapping subset names to lists of JSON objects
"""
if str(self.path).endswith(".bz2"):
logger.debug(f"Loading database split from {str(self.path)}...")
with __import__("bz2").open(self.path) as f:
if str(self._path).endswith(".bz2"):
logger.debug(f"Loading database split from {str(self._path)}...")
with __import__("bz2").open(self._path) as f:
return json.load(f)
else:
with self.path.open() as f:
with self._path.open() as f:
return json.load(f)
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
return self._subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
return iter(self._subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
return len(self._subsets)
class CSVDatabaseSplit(DatabaseSplit):
......@@ -149,8 +149,8 @@ class CSVDatabaseSplit(DatabaseSplit):
assert (
directory.is_dir()
), f"`{str(directory)}` is not a valid directory"
self.directory = directory
self.subsets = self._load_split_from_disk()
self._directory = directory
self._subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation.
......@@ -168,7 +168,7 @@ class CSVDatabaseSplit(DatabaseSplit):
"""
retval: DatabaseSplit = {}
for subset in self.directory.iterdir():
for subset in self._directory.iterdir():
if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...")
with __import__("bz2").open(subset) as f:
......@@ -188,15 +188,15 @@ class CSVDatabaseSplit(DatabaseSplit):
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
return self._subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
return iter(self._subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
return len(self._subsets)
def check_database_split_loading(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment