Commit 313d962d authored by André Anjos

[data.split] Make splits lazy-loadable (closes #27)

parent 1120d9df
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import csv
import functools
import importlib.abc
import json
import logging
@@ -26,7 +27,7 @@ class JSONDatabaseSplit(DatabaseSplit):
.. code-block:: json
{
"subset1": [
"dataset1": [
[
"sample1-data1",
"sample1-data2",
@@ -38,7 +39,7 @@ class JSONDatabaseSplit(DatabaseSplit):
"sample2-data3",
]
],
"subset2": [
"dataset2": [
[
"sample42-data1",
"sample42-data2",
@@ -47,14 +48,16 @@ class JSONDatabaseSplit(DatabaseSplit):
]
}
Your database split may contain any number of subsets (dictionary keys).
For simplicity, we recommend all sample entries are formatted similarly so
that raw-data-loading is simplified. Use the function
Your database split may contain any number of (raw) datasets (dictionary
keys). For simplicity, we recommend all sample entries are formatted
similarly so that raw-data-loading is simplified. Use the function
:py:func:`check_database_split_loading` to test raw data loading and
fine-tune the dataset split, or its loading.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Objects of this class behave like a dictionary in which keys are dataset
names in the split, and values represent samples data and meta-data. The
actual JSON file contents are only loaded on demand, using
a :py:func:`functools.cached_property`.
Parameters
@@ -69,21 +72,20 @@ class JSONDatabaseSplit(DatabaseSplit):
if isinstance(path, str):
path = pathlib.Path(path)
self._path = path
self._subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation.
@functools.cached_property
def _datasets(self) -> DatabaseSplit:
"""Datasets in a split.
This method will load JSON information for the current split and return
all subsets of the given split after converting each entry through the
loader function.
The first access to this (cached) property triggers loading of the full
JSON file from disk. Subsequent accesses reuse the cached result.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
datasets : dict
A dictionary mapping dataset names to lists of JSON objects
"""
if str(self._path).endswith(".bz2"):
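The core of this change is visible in the hunk above: instead of parsing the
JSON file eagerly inside ``__init__``, parsing now happens on the first access
to the cached ``_datasets`` property. A minimal, self-contained sketch of the
same pattern (class and file names are illustrative, not the project's actual
code):

.. code-block:: python

   import functools
   import json
   import pathlib


   class LazyJSONSplit:
       """Illustrative only: JSON parsing is deferred until first use."""

       def __init__(self, path):
           # no I/O happens here; we only remember where the file lives
           self._path = pathlib.Path(path)

       @functools.cached_property
       def _datasets(self):
           # runs exactly once, on first attribute access; the parsed
           # dictionary is then stored on the instance and reused
           with self._path.open() as f:
               return json.load(f)

       def __getitem__(self, key):
           return self._datasets[key]  # the first call triggers the load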
@@ -95,16 +97,16 @@ class JSONDatabaseSplit(DatabaseSplit):
return json.load(f)
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self._subsets[key]
"""Accesses dataset ``key`` from this split."""
return self._datasets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self._subsets)
"""Iterates over the datasets."""
return iter(self._datasets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self._subsets)
"""How many datasets we currently have."""
return len(self._datasets)
class CSVDatabaseSplit(DatabaseSplit):
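Both split classes expose the same ``__getitem__``, ``__iter__`` and
``__len__`` trio over their cached ``_datasets`` dictionary, so a split can be
consumed like a read-only mapping while still being loaded lazily. A small
standalone sketch of that behaviour (the class and sample values below are
made up for illustration; deriving from :py:class:`collections.abc.Mapping` is
a convenience of the sketch, not something implied by the diff):

.. code-block:: python

   import collections.abc
   import functools


   class LazySplit(collections.abc.Mapping):
       """Illustrative: a read-only mapping whose payload is built on demand."""

       @functools.cached_property
       def _datasets(self):
           print("loading...")  # happens exactly once
           return {"train": [["s1", "label-a"]], "test": [["s2", "label-b"]]}

       def __getitem__(self, key):
           return self._datasets[key]

       def __iter__(self):
           return iter(self._datasets)

       def __len__(self):
           return len(self._datasets)


   split = LazySplit()
   print("train" in split)  # True -- this first lookup triggers "loading..."
   print(list(split))       # ['train', 'test']
   print(len(split))        # 2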
@@ -112,7 +114,7 @@ class CSVDatabaseSplit(DatabaseSplit):
CSV format.
To create a new database split, you need to provide one or more CSV
formatted files, each representing a subset of this split, containing the
formatted files, each representing a dataset of this split, containing the
sample data (one per row). Example:
Inside the directory ``my-split/``, one can find files ``train.csv``,
@@ -125,11 +127,11 @@ class CSVDatabaseSplit(DatabaseSplit):
sample2-value1,sample2-value2,sample2-value3
...
Each file in the provided directory defines the subset name on the split.
So, the file ``train.csv`` will contain the data from the ``train`` subset,
Each file in the provided directory defines the dataset name on the split.
So, the file ``train.csv`` will contain the data from the ``train`` dataset,
and so on.
Objects of this class behave like a dictionary in which keys are subset
Objects of this class behave like a dictionary in which keys are dataset
names in the split, and values represent samples data and meta-data.
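For the CSV flavour, every ``*.csv`` or ``*.csv.bz2`` file found in the
directory becomes one dataset, keyed by the file name minus its extension, and
nothing is read from disk until ``_datasets`` is first accessed. A
self-contained sketch of the per-file parsing (directory, file names and
values are invented; bz2 files are opened in text mode here for simplicity):

.. code-block:: python

   import bz2
   import csv
   import pathlib
   import tempfile

   # build a tiny split directory: one plain CSV and one bz2-compressed CSV
   tmp = pathlib.Path(tempfile.mkdtemp())
   (tmp / "train.csv").write_text("s1-v1,s1-v2\ns2-v1,s2-v2\n")
   with bz2.open(tmp / "test.csv.bz2", "wt") as f:
       f.write("s42-v1,s42-v2\n")

   # one dataset per file, keyed by the file name without its extension
   datasets = {}
   for path in tmp.iterdir():
       if path.name.endswith(".csv.bz2"):
           with bz2.open(path, mode="rt") as f:
               datasets[path.name[: -len(".csv.bz2")]] = list(csv.reader(f))
       elif path.name.endswith(".csv"):
           with path.open() as f:
               datasets[path.name[: -len(".csv")]] = list(csv.reader(f))

   print(datasets["train"])  # [['s1-v1', 's1-v2'], ['s2-v1', 's2-v2']]
   print(datasets["test"])   # [['s42-v1', 's42-v2']]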
@@ -138,7 +140,7 @@ class CSVDatabaseSplit(DatabaseSplit):
directory
Absolute path to a directory containing the database split laid down
as a set of CSV files, one per subset.
as a set of CSV files, one per dataset.
"""
def __init__(
@@ -150,53 +152,52 @@ class CSVDatabaseSplit(DatabaseSplit):
directory.is_dir()
), f"`{str(directory)}` is not a valid directory"
self._directory = directory
self._subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation.
@functools.cached_property
def _datasets(self) -> DatabaseSplit:
"""Datasets in a split.
This method will load CSV information for the current split and return all
subsets of the given split after converting each entry through the
loader function.
The first access to this (cached) property triggers loading of all CSV
files from disk. Subsequent accesses reuse the cached result.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
datasets : dict
A dictionary mapping dataset names to lists of CSV rows
"""
retval: DatabaseSplit = {}
for subset in self._directory.iterdir():
if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...")
with __import__("bz2").open(subset) as f:
retval: dict[str, typing.Sequence[typing.Any]] = {}
for dataset in self._directory.iterdir():
if str(dataset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {dataset}...")
with __import__("bz2").open(dataset) as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv.bz2")]] = [
retval[dataset.name[: -len(".csv.bz2")]] = [
k for k in reader
]
elif str(subset).endswith(".csv"):
with subset.open() as f:
elif str(dataset).endswith(".csv"):
with dataset.open() as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv")]] = [k for k in reader]
retval[dataset.name[: -len(".csv")]] = [k for k in reader]
else:
logger.debug(
f"Ignoring file {subset} in CSVDatabaseSplit readout"
f"Ignoring file {dataset} in CSVDatabaseSplit readout"
)
return retval
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self._subsets[key]
"""Accesses dataset ``key`` from this split."""
return self._datasets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self._subsets)
"""Iterates over the datasets."""
return iter(self._datasets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self._subsets)
"""How many datasets we currently have."""
return len(self._datasets)
def check_database_split_loading(
@@ -204,7 +205,7 @@ def check_database_split_loading(
loader: RawDataLoader,
limit: int = 0,
) -> int:
"""For each subset in the split, check if all data can be correctly loaded
"""For each dataset in the split, check if all data can be correctly loaded
using the provided loader function.
This function will return the number of errors loading samples, and will
@@ -216,14 +217,14 @@ def check_database_split_loading(
database_split
A mapping that contains the database split. Each key represents the
name of a subset in the split. Each value is a (potentially complex)
name of a dataset in the split. Each value is a (potentially complex)
object that represents a single sample.
loader
A loader object that knows how to handle full-samples or just labels.
limit
Maximum number of samples to check (in each split/subset
Maximum number of samples to check (in each split/dataset
combination) in this dataset. If set to zero, then check
everything.
@@ -235,10 +236,10 @@ def check_database_split_loading(
Number of errors found
"""
logger.info(
"Checking if can load all samples in all subsets of this split..."
"Checking if can load all samples in all datasets of this split..."
)
errors = 0
for subset, samples in database_split.items():
for dataset, samples in database_split.items():
samples = samples if not limit else samples[:limit]
for pos, sample in enumerate(samples):
try:
@@ -246,7 +247,7 @@ def check_database_split_loading(
assert isinstance(data, torch.Tensor)
except Exception as e:
logger.info(
f"Found error loading entry {pos} in subset `{subset}`: {e}"
f"Found error loading entry {pos} in dataset `{dataset}`: {e}"
)
errors += 1
return errors
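The checker boils down to a simple pattern: walk over every dataset,
optionally cap the number of samples via ``limit``, try to load each entry,
and count failures instead of aborting. A standalone restatement of that loop,
with a plain callable standing in for the project's ``RawDataLoader`` (whose
exact interface is not shown in this hunk):

.. code-block:: python

   import logging

   logger = logging.getLogger(__name__)


   def count_loading_errors(split, load, limit=0):
       """Illustrative restatement of the checking loop above."""
       errors = 0
       for dataset, samples in split.items():
           samples = samples if not limit else samples[:limit]
           for pos, sample in enumerate(samples):
               try:
                   load(sample)  # stand-in for the real loader call
               except Exception as e:
                   logger.info(
                       f"Found error loading entry {pos} in dataset `{dataset}`: {e}"
                   )
                   errors += 1
       return errors


   # toy split and loader: loading fails when the second field is not numeric
   split = {"train": [["a", "1"], ["b", "oops"]], "test": [["c", "2"]]}
   print(count_loading_errors(split, lambda s: float(s[1])))  # -> 1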