Skip to content
Snippets Groups Projects
Commit c2a48fb5 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[test;data.csv] Add tests and fixes to CSV dataset; Support fileobj as input

parent fb43d827
No related branches found
No related tags found
1 merge request!12Streamlining
......@@ -2,8 +2,10 @@
# coding=utf-8
import os
import csv
import copy
import json
import pathlib
import functools
import logging
......@@ -49,8 +51,9 @@ class JSONDataset:
protocols : list, dict
Paths to one or more JSON formatted files containing the various
protocols to be recognized by this dataset, or a dictionary, mapping
protocol names to paths of JSON files. Internally, we save a
dictionary where keys default to the basename of paths.
protocol names to paths (or opened file objects) of CSV files.
Internally, we save a dictionary where keys default to the basename of
paths (list input).
fieldnames : list, tuple
An iterable over the field names (strings) to assign to each entry in
......@@ -74,12 +77,10 @@ class JSONDataset:
def __init__(self, protocols, fieldnames, loader, keymaker):
if isinstance(protocols, dict):
self.protocols = dict(
(k, os.path.realpath(v)) for k, v in protocols.items()
)
self.protocols = protocols
else:
self.protocols = dict(
(os.path.splitext(os.path.basename(k))[0], os.path.realpath(k))
(os.path.splitext(os.path.basename(k))[0], k)
for k in protocols
)
self.fieldnames = fieldnames
......@@ -155,8 +156,13 @@ class JSONDataset:
"""
with open(self.protocols[protocol], "r") as f:
fileobj = self.protocols[protocol]
if isinstance(fileobj, (str, bytes, pathlib.Path)):
with open(self.protocols[protocol], "r") as f:
data = json.load(f)
else:
data = json.load(f)
fileobj.seek(0)
retval = {}
for subset, samples in data.items():
......@@ -187,10 +193,10 @@ class CSVDataset:
----------
subsets : list, dict
Paths to one or more CSV formatted files containing the various
subsets to be recognized by this dataset, or a dictionary, mapping
subset names to paths of CSV files. Internally, we save a
dictionary where keys default to the basename of paths.
Paths to one or more CSV formatted files containing the various subsets
to be recognized by this dataset, or a dictionary, mapping subset names
to paths (or opened file objects) of CSV files. Internally, we save a
dictionary where keys default to the basename of paths (list input).
fieldnames : list, tuple
An iterable over the field names (strings) to assign to each column in
......@@ -213,12 +219,10 @@ class CSVDataset:
def __init__(self, subsets, fieldnames, loader, keymaker):
if isinstance(subsets, dict):
self.subsets = dict(
(k, os.path.realpath(v)) for k, v in subsets.items()
)
self.subsets = subsets
else:
self.subsets = dict(
(os.path.splitext(os.path.basename(k))[0], os.path.realpath(k))
(os.path.splitext(os.path.basename(k))[0], k)
for k in subsets
)
self.fieldnames = fieldnames
......@@ -257,7 +261,7 @@ class CSVDataset:
f"entries instead of {len(self.fieldnames)} (expected). Fix "
f"file {self.subsets[context['subset']]}"
)
item = dict(zip(self.fieldnames, v))
item = dict(zip(self.fieldnames, sample))
return DelayedSample(
functools.partial(self.loader, context, item),
key=self.keymaker(context, item),
......@@ -291,9 +295,15 @@ class CSVDataset:
"""
with open(self.subsets[subset], newline="") as f:
cf = csv.reader(f)
fileobj = self.subsets[subset]
if isinstance(fileobj, (str, bytes, pathlib.Path)):
with open(self.subsets[subset], newline="") as f:
cf = csv.reader(f)
samples = [k for k in cf]
else:
cf = csv.reader(fileobj)
samples = [k for k in cf]
fileobj.seek(0)
context = dict(subset=subset)
return [self._make_delayed(k, v, context) for (k, v) in enumerate(samples)]
#!/usr/bin/env python
# coding=utf-8
"""Unit tests for the CSV dataset"""
import io
import nose.tools
from ..data.dataset import CSVDataset
from ..data import stare
## special trick for CI builds
from . import mock_dataset, TESTDB_TMPDIR
json_dataset, rc_variable_set = mock_dataset()
## definition of stare subsets for "default" protocol
default = {
"train": io.StringIO(
"""\
stare-images/im0001.ppm,labels-ah/im0001.ah.ppm
stare-images/im0002.ppm,labels-ah/im0002.ah.ppm
stare-images/im0003.ppm,labels-ah/im0003.ah.ppm
stare-images/im0004.ppm,labels-ah/im0004.ah.ppm
stare-images/im0005.ppm,labels-ah/im0005.ah.ppm
stare-images/im0044.ppm,labels-ah/im0044.ah.ppm
stare-images/im0077.ppm,labels-ah/im0077.ah.ppm
stare-images/im0081.ppm,labels-ah/im0081.ah.ppm
stare-images/im0082.ppm,labels-ah/im0082.ah.ppm
stare-images/im0139.ppm,labels-ah/im0139.ah.ppm"""
),
"test": io.StringIO(
"""\
stare-images/im0162.ppm,labels-ah/im0162.ah.ppm
stare-images/im0163.ppm,labels-ah/im0163.ah.ppm
stare-images/im0235.ppm,labels-ah/im0235.ah.ppm
stare-images/im0236.ppm,labels-ah/im0236.ah.ppm
stare-images/im0239.ppm,labels-ah/im0239.ah.ppm
stare-images/im0240.ppm,labels-ah/im0240.ah.ppm
stare-images/im0255.ppm,labels-ah/im0255.ah.ppm
stare-images/im0291.ppm,labels-ah/im0291.ah.ppm
stare-images/im0319.ppm,labels-ah/im0319.ah.ppm
stare-images/im0324.ppm,labels-ah/im0324.ah.ppm"""
),
}
@rc_variable_set("bob.ip.binseg.stare.datadir")
def test_compare_to_json():
if TESTDB_TMPDIR is not None:
stare_dir = TESTDB_TMPDIR.name
else:
import bob.extension
stare_dir = bob.extension.rc.get("bob.ip.binseg.stare.datadir")
test_dataset = CSVDataset(
default,
stare._fieldnames,
stare._make_loader(stare_dir),
stare.data_path_keymaker,
)
for subset in ("train", "test"):
for t1, t2 in zip(
test_dataset.samples(subset),
json_dataset.subsets("default")[subset],
):
nose.tools.eq_(t1.key, t2.key)
nose.tools.eq_(t1.data, t2.data)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment