Commit e6db9362 authored by André Anjos

[data.split] Make splits to be lazy-loadable (closes #27)

parent eea2f306
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
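Note: the gist of this change is to defer disk I/O from ``__init__`` to the first attribute access, via :py:func:`functools.cached_property`. A minimal sketch of the pattern follows; ``LazySplit`` and the file name are illustrative only, not part of the module's API:

    import functools
    import json


    class LazySplit:
        """Illustration only: defers JSON parsing until first access."""

        def __init__(self, path: str):
            # no disk I/O at construction time: building the object is cheap
            self._path = path

        @functools.cached_property
        def _datasets(self) -> dict:
            # runs only once; the result is stored on the instance and
            # reused by every later access to ``self._datasets``
            with open(self._path) as f:
                return json.load(f)


    split = LazySplit("my-split.json")  # nothing is read yet
    # the first access to ``split._datasets`` triggers the (single) load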
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 import csv
+import functools
 import importlib.abc
 import json
 import logging
@@ -26,7 +27,7 @@ class JSONDatabaseSplit(DatabaseSplit):
     .. code-block:: json

        {
-           "subset1": [
+           "dataset1": [
                [
                    "sample1-data1",
                    "sample1-data2",
@@ -38,7 +39,7 @@ class JSONDatabaseSplit(DatabaseSplit):
                    "sample2-data3",
                ]
            ],
-           "subset2": [
+           "dataset2": [
                [
                    "sample42-data1",
                    "sample42-data2",
@@ -47,14 +48,16 @@ class JSONDatabaseSplit(DatabaseSplit):
            ]
        }

-    Your database split many contain any number of subsets (dictionary keys).
-    For simplicity, we recommend all sample entries are formatted similarly so
-    that raw-data-loading is simplified. Use the function
+    Your database split may contain any number of (raw) datasets (dictionary
+    keys). For simplicity, we recommend all sample entries are formatted
+    similarly so that raw-data-loading is simplified. Use the function
     :py:func:`check_database_split_loading` to test raw data loading and fine
     tune the dataset split, or its loading.

-    Objects of this class behave like a dictionary in which keys are subset
-    names in the split, and values represent samples data and meta-data.
+    Objects of this class behave like a dictionary in which keys are dataset
+    names in the split, and values represent sample data and metadata. The
+    actual JSON file contents are loaded on demand through
+    :py:func:`functools.cached_property`.

     Parameters
@@ -69,21 +72,20 @@ class JSONDatabaseSplit(DatabaseSplit):
         if isinstance(path, str):
             path = pathlib.Path(path)
         self._path = path
-        self._subsets = self._load_split_from_disk()

-    def _load_split_from_disk(self) -> DatabaseSplit:
-        """Loads all subsets in a split from its file system representation.
+    @functools.cached_property
+    def _datasets(self) -> DatabaseSplit:
+        """Datasets in a split.

-        This method will load JSON information for the current split and return
-        all subsets of the given split after converting each entry through the
-        loader function.
+        The first call to this (cached) property will trigger full JSON file
+        loading from disk. Subsequent calls reuse the cached value.

         Returns
         -------

-        subsets : dict
-            A dictionary mapping subset names to lists of JSON objects
+        datasets : dict
+            A dictionary mapping dataset names to lists of JSON objects
         """

         if str(self._path).endswith(".bz2"):
@@ -95,16 +97,16 @@ class JSONDatabaseSplit(DatabaseSplit):
             return json.load(f)

     def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
-        """Accesses subset ``key`` from this split."""
-        return self._subsets[key]
+        """Accesses dataset ``key`` from this split."""
+        return self._datasets[key]

     def __iter__(self):
-        """Iterates over the subsets."""
-        return iter(self._subsets)
+        """Iterates over the datasets."""
+        return iter(self._datasets)

     def __len__(self) -> int:
-        """How many subsets we currently have."""
-        return len(self._subsets)
+        """How many datasets we currently have."""
+        return len(self._datasets)


 class CSVDatabaseSplit(DatabaseSplit):
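With the property above, instantiating a ``JSONDatabaseSplit`` no longer touches the disk. A usage sketch; the import path is an assumption inferred from the ``[data.split]`` commit scope, and the file name is hypothetical:

    from ptbench.data.split import JSONDatabaseSplit  # import path assumed

    split = JSONDatabaseSplit("my-split.json")  # ".json" or ".json.bz2"; no I/O yet
    print(len(split))          # first touch of ``_datasets``: the file is parsed here
    train = split["dataset1"]  # served from the cached dictionary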
@@ -112,7 +114,7 @@ class CSVDatabaseSplit(DatabaseSplit):
     CSV format.

     To create a new database split, you need to provide one or more CSV
-    formatted files, each representing a subset of this split, containing the
+    formatted files, each representing a dataset of this split, containing the
     sample data (one per row). Example:

     Inside the directory ``my-split/``, one can find files ``train.csv``,
@@ -125,11 +127,11 @@ class CSVDatabaseSplit(DatabaseSplit):
        sample2-value1,sample2-value2,sample2-value3
        ...

-    Each file in the provided directory defines the subset name on the split.
-    So, the file ``train.csv`` will contain the data from the ``train`` subset,
+    Each file in the provided directory defines the dataset name on the split.
+    So, the file ``train.csv`` will contain the data from the ``train`` dataset,
     and so on.

-    Objects of this class behave like a dictionary in which keys are subset
+    Objects of this class behave like a dictionary in which keys are dataset
     names in the split, and values represent sample data and metadata.
@@ -138,7 +140,7 @@ class CSVDatabaseSplit(DatabaseSplit):
     directory
         Absolute path to a directory containing the database split laid down
-        as a set of CSV files, one per subset.
+        as a set of CSV files, one per dataset.
     """

     def __init__(
@@ -150,53 +152,52 @@ class CSVDatabaseSplit(DatabaseSplit):
             directory.is_dir()
         ), f"`{str(directory)}` is not a valid directory"
         self._directory = directory
-        self._subsets = self._load_split_from_disk()

-    def _load_split_from_disk(self) -> DatabaseSplit:
-        """Loads all subsets in a split from its file system representation.
+    @functools.cached_property
+    def _datasets(self) -> DatabaseSplit:
+        """Datasets in a split.

-        This method will load CSV information for the current split and return all
-        subsets of the given split after converting each entry through the
-        loader function.
+        The first call to this (cached) property will trigger all CSV file
+        loading from disk. Subsequent calls reuse the cached value.

         Returns
         -------

-        subsets : dict
-            A dictionary mapping subset names to lists of JSON objects
+        datasets : dict
+            A dictionary mapping dataset names to lists of CSV rows
         """

-        retval: DatabaseSplit = {}
-        for subset in self._directory.iterdir():
-            if str(subset).endswith(".csv.bz2"):
-                logger.debug(f"Loading database split from {subset}...")
-                with __import__("bz2").open(subset) as f:
-                    reader = csv.reader(f)
-                    retval[subset.name[: -len(".csv.bz2")]] = [
-                        k for k in reader
-                    ]
-            elif str(subset).endswith(".csv"):
-                with subset.open() as f:
-                    reader = csv.reader(f)
-                    retval[subset.name[: -len(".csv")]] = [k for k in reader]
-            else:
-                logger.debug(
-                    f"Ignoring file {subset} in CSVDatabaseSplit readout"
-                )
+        retval: dict[str, typing.Sequence[typing.Any]] = {}
+        for dataset in self._directory.iterdir():
+            if str(dataset).endswith(".csv.bz2"):
+                logger.debug(f"Loading database split from {dataset}...")
+                with __import__("bz2").open(dataset) as f:
+                    reader = csv.reader(f)
+                    retval[dataset.name[: -len(".csv.bz2")]] = [
+                        k for k in reader
+                    ]
+            elif str(dataset).endswith(".csv"):
+                with dataset.open() as f:
+                    reader = csv.reader(f)
+                    retval[dataset.name[: -len(".csv")]] = [k for k in reader]
+            else:
+                logger.debug(
+                    f"Ignoring file {dataset} in CSVDatabaseSplit readout"
+                )
         return retval

     def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
-        """Accesses subset ``key`` from this split."""
-        return self._subsets[key]
+        """Accesses dataset ``key`` from this split."""
+        return self._datasets[key]

     def __iter__(self):
-        """Iterates over the subsets."""
-        return iter(self._subsets)
+        """Iterates over the datasets."""
+        return iter(self._datasets)

     def __len__(self) -> int:
-        """How many subsets we currently have."""
-        return len(self._subsets)
+        """How many datasets we currently have."""
+        return len(self._datasets)


 def check_database_split_loading(
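The CSV variant behaves the same way: the constructor only validates the directory, and the first dictionary access reads every ``.csv``/``.csv.bz2`` file it contains. A sketch under the same import-path assumption, with a hypothetical directory name:

    import pathlib

    from ptbench.data.split import CSVDatabaseSplit  # import path assumed

    split = CSVDatabaseSplit(pathlib.Path("my-split"))  # only checks is_dir(); no I/O yet
    train = split["train"]  # first access reads train.csv and every sibling CSV file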
@@ -204,7 +205,7 @@ def check_database_split_loading(
     loader: RawDataLoader,
     limit: int = 0,
 ) -> int:
-    """For each subset in the split, check if all data can be correctly loaded
+    """For each dataset in the split, check if all data can be correctly loaded
     using the provided loader function.

     This function will return the number of errors loading samples, and will
@@ -216,14 +217,14 @@ def check_database_split_loading(
     database_split
         A mapping that contains the database split. Each key represents the
-        name of a subset in the split. Each value is a (potentially complex)
+        name of a dataset in the split. Each value is a (potentially complex)
         object that represents a single sample.

     loader
         A loader object that knows how to handle full-samples or just labels.

     limit
-        Maximum number of samples to check (in each split/subset
+        Maximum number of samples to check (in each split/dataset
         combination) in this dataset. If set to zero, then check
         everything.
@@ -235,10 +236,10 @@ def check_database_split_loading(
         Number of errors found
     """

     logger.info(
-        "Checking if can load all samples in all subsets of this split..."
+        "Checking if we can load all samples in all datasets of this split..."
     )
     errors = 0
-    for subset, samples in database_split.items():
+    for dataset, samples in database_split.items():
         samples = samples if not limit else samples[:limit]
         for pos, sample in enumerate(samples):
             try:
@@ -246,7 +247,7 @@ def check_database_split_loading(
                 assert isinstance(data, torch.Tensor)
             except Exception as e:
                 logger.info(
-                    f"Found error loading entry {pos} in subset `{subset}`: {e}"
+                    f"Found error loading entry {pos} in dataset `{dataset}`: {e}"
                 )
                 errors += 1
     return errors
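Finally, a hypothetical smoke test built on ``check_database_split_loading()``, assuming ``split`` from the sketches above and an already-constructed ``RawDataLoader`` instance ``loader`` (its construction is database-specific and not shown in this diff):

    # check only the first 5 samples of every dataset in the split
    errors = check_database_split_loading(split, loader, limit=5)
    assert errors == 0, f"{errors} sample(s) failed to load"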