Skip to content
Snippets Groups Projects
Commit 795e4494 authored by André Anjos
Browse files

[data.split] Make variables private

parent bfc106ab
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
...@@ -686,9 +686,7 @@ class CachingDataModule(lightning.LightningDataModule): ...@@ -686,9 +686,7 @@ class CachingDataModule(lightning.LightningDataModule):
def _val_dataset_keys(self) -> list[str]: def _val_dataset_keys(self) -> list[str]:
"""Returns list of validation dataset names.""" """Returns list of validation dataset names."""
return ["validation"] + [ return ["validation"] + [
k k for k in self.database_split.keys() if k.startswith("monitor-")
for k in self.database_split.subsets.keys()
if k.startswith("monitor-")
] ]
def setup(self, stage: str) -> None: def setup(self, stage: str) -> None:
...@@ -729,7 +727,7 @@ class CachingDataModule(lightning.LightningDataModule): ...@@ -729,7 +727,7 @@ class CachingDataModule(lightning.LightningDataModule):
self._setup_dataset("test") self._setup_dataset("test")
elif stage == "predict": elif stage == "predict":
for k in self.database_split.subsets.keys(): for k in self.database_split.keys():
self._setup_dataset(k) self._setup_dataset(k)
def teardown(self, stage: str) -> None: def teardown(self, stage: str) -> None:
......
...@@ -68,8 +68,8 @@ class JSONDatabaseSplit(DatabaseSplit): ...@@ -68,8 +68,8 @@ class JSONDatabaseSplit(DatabaseSplit):
def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable): def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
if isinstance(path, str): if isinstance(path, str):
path = pathlib.Path(path) path = pathlib.Path(path)
self.path = path self._path = path
self.subsets = self._load_split_from_disk() self._subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> DatabaseSplit: def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation. """Loads all subsets in a split from its file system representation.
...@@ -86,25 +86,25 @@ class JSONDatabaseSplit(DatabaseSplit): ...@@ -86,25 +86,25 @@ class JSONDatabaseSplit(DatabaseSplit):
A dictionary mapping subset names to lists of JSON objects A dictionary mapping subset names to lists of JSON objects
""" """
if str(self.path).endswith(".bz2"): if str(self._path).endswith(".bz2"):
logger.debug(f"Loading database split from {str(self.path)}...") logger.debug(f"Loading database split from {str(self._path)}...")
with __import__("bz2").open(self.path) as f: with __import__("bz2").open(self._path) as f:
return json.load(f) return json.load(f)
else: else:
with self.path.open() as f: with self._path.open() as f:
return json.load(f) return json.load(f)
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split.""" """Accesses subset ``key`` from this split."""
return self.subsets[key] return self._subsets[key]
def __iter__(self): def __iter__(self):
"""Iterates over the subsets.""" """Iterates over the subsets."""
return iter(self.subsets) return iter(self._subsets)
def __len__(self) -> int: def __len__(self) -> int:
"""How many subsets we currently have.""" """How many subsets we currently have."""
return len(self.subsets) return len(self._subsets)
class CSVDatabaseSplit(DatabaseSplit): class CSVDatabaseSplit(DatabaseSplit):
...@@ -149,8 +149,8 @@ class CSVDatabaseSplit(DatabaseSplit): ...@@ -149,8 +149,8 @@ class CSVDatabaseSplit(DatabaseSplit):
assert ( assert (
directory.is_dir() directory.is_dir()
), f"`{str(directory)}` is not a valid directory" ), f"`{str(directory)}` is not a valid directory"
self.directory = directory self._directory = directory
self.subsets = self._load_split_from_disk() self._subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> DatabaseSplit: def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation. """Loads all subsets in a split from its file system representation.
...@@ -168,7 +168,7 @@ class CSVDatabaseSplit(DatabaseSplit): ...@@ -168,7 +168,7 @@ class CSVDatabaseSplit(DatabaseSplit):
""" """
retval: DatabaseSplit = {} retval: DatabaseSplit = {}
for subset in self.directory.iterdir(): for subset in self._directory.iterdir():
if str(subset).endswith(".csv.bz2"): if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...") logger.debug(f"Loading database split from {subset}...")
with __import__("bz2").open(subset) as f: with __import__("bz2").open(subset) as f:
...@@ -188,15 +188,15 @@ class CSVDatabaseSplit(DatabaseSplit): ...@@ -188,15 +188,15 @@ class CSVDatabaseSplit(DatabaseSplit):
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split.""" """Accesses subset ``key`` from this split."""
return self.subsets[key] return self._subsets[key]
def __iter__(self): def __iter__(self):
"""Iterates over the subsets.""" """Iterates over the subsets."""
return iter(self.subsets) return iter(self._subsets)
def __len__(self) -> int: def __len__(self) -> int:
"""How many subsets we currently have.""" """How many subsets we currently have."""
return len(self.subsets) return len(self._subsets)
def check_database_split_loading( def check_database_split_loading(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment