diff --git a/bob/ip/binseg/data/dataset.py b/bob/ip/binseg/data/dataset.py
index 687a7d0bab474f774ffe2a6bfca6d703f0acf7da..7ad0e9f2dfd29e4d4f1975d7c03cdab99cdf98b9 100644
--- a/bob/ip/binseg/data/dataset.py
+++ b/bob/ip/binseg/data/dataset.py
@@ -2,8 +2,10 @@
 # coding=utf-8
 
 import os
+import csv
 import copy
 import json
+import pathlib
 import functools
 
 import logging
@@ -49,8 +51,9 @@ class JSONDataset:
     protocols : list, dict
         Paths to one or more JSON formatted files containing the various
         protocols to be recognized by this dataset, or a dictionary, mapping
-        protocol names to paths of JSON files.  Internally, we save a
-        dictionary where keys default to the basename of paths.
+        protocol names to paths (or opened file objects) of JSON files.
+        Internally, we save a dictionary where keys default to the basename of
+        paths (list input).
 
     fieldnames : list, tuple
         An iterable over the field names (strings) to assign to each entry in
@@ -74,12 +77,10 @@
     def __init__(self, protocols, fieldnames, loader, keymaker):
 
         if isinstance(protocols, dict):
-            self.protocols = dict(
-                (k, os.path.realpath(v)) for k, v in protocols.items()
-            )
+            self.protocols = protocols
         else:
             self.protocols = dict(
-                (os.path.splitext(os.path.basename(k))[0], os.path.realpath(k))
+                (os.path.splitext(os.path.basename(k))[0], k)
                 for k in protocols
             )
         self.fieldnames = fieldnames
@@ -155,8 +156,13 @@
 
         """
 
-        with open(self.protocols[protocol], "r") as f:
-            data = json.load(f)
+        fileobj = self.protocols[protocol]
+        if isinstance(fileobj, (str, bytes, pathlib.Path)):
+            with open(self.protocols[protocol], "r") as f:
+                data = json.load(f)
+        else:
+            data = json.load(fileobj)
+            fileobj.seek(0)
 
         retval = {}
         for subset, samples in data.items():
@@ -187,10 +193,10 @@ class CSVDataset:
     ----------
 
     subsets : list, dict
-        Paths to one or more CSV formatted files containing the various
-        subsets to be recognized by this dataset, or a dictionary, mapping
-        subset names to paths of CSV files.  Internally, we save a
-        dictionary where keys default to the basename of paths.
+        Paths to one or more CSV formatted files containing the various subsets
+        to be recognized by this dataset, or a dictionary, mapping subset names
+        to paths (or opened file objects) of CSV files.  Internally, we save a
+        dictionary where keys default to the basename of paths (list input).
 
     fieldnames : list, tuple
         An iterable over the field names (strings) to assign to each column in
@@ -213,12 +219,10 @@
     def __init__(self, subsets, fieldnames, loader, keymaker):
 
         if isinstance(subsets, dict):
-            self.subsets = dict(
-                (k, os.path.realpath(v)) for k, v in subsets.items()
-            )
+            self.subsets = subsets
         else:
             self.subsets = dict(
-                (os.path.splitext(os.path.basename(k))[0], os.path.realpath(k))
+                (os.path.splitext(os.path.basename(k))[0], k)
                 for k in subsets
            )
         self.fieldnames = fieldnames
@@ -257,7 +261,7 @@
                 f"entries instead of {len(self.fieldnames)} (expected). Fix "
                 f"file {self.subsets[context['subset']]}"
             )
-        item = dict(zip(self.fieldnames, v))
+        item = dict(zip(self.fieldnames, sample))
         return DelayedSample(
             functools.partial(self.loader, context, item),
             key=self.keymaker(context, item),
@@ -291,9 +295,15 @@
 
         """
 
-        with open(self.subsets[subset], newline="") as f:
-            cf = csv.reader(f)
+        fileobj = self.subsets[subset]
+        if isinstance(fileobj, (str, bytes, pathlib.Path)):
+            with open(self.subsets[subset], newline="") as f:
+                cf = csv.reader(f)
+                samples = [k for k in cf]
+        else:
+            cf = csv.reader(fileobj)
             samples = [k for k in cf]
+            fileobj.seek(0)
 
         context = dict(subset=subset)
         return [self._make_delayed(k, v, context) for (k, v) in enumerate(samples)]
diff --git a/bob/ip/binseg/test/test_csv.py b/bob/ip/binseg/test/test_csv.py
new file mode 100644
index 0000000000000000000000000000000000000000..39fc3c80f423ab5d171e9852d7ef30c009e8deec
--- /dev/null
+++ b/bob/ip/binseg/test/test_csv.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+"""Unit tests for the CSV dataset"""
+
+import io
+
+import nose.tools
+
+from ..data.dataset import CSVDataset
+from ..data import stare
+
+## special trick for CI builds
+from . import mock_dataset, TESTDB_TMPDIR
+
+json_dataset, rc_variable_set = mock_dataset()
+
+
+## definition of stare subsets for "default" protocol
+default = {
+    "train": io.StringIO(
+        """\
+stare-images/im0001.ppm,labels-ah/im0001.ah.ppm
+stare-images/im0002.ppm,labels-ah/im0002.ah.ppm
+stare-images/im0003.ppm,labels-ah/im0003.ah.ppm
+stare-images/im0004.ppm,labels-ah/im0004.ah.ppm
+stare-images/im0005.ppm,labels-ah/im0005.ah.ppm
+stare-images/im0044.ppm,labels-ah/im0044.ah.ppm
+stare-images/im0077.ppm,labels-ah/im0077.ah.ppm
+stare-images/im0081.ppm,labels-ah/im0081.ah.ppm
+stare-images/im0082.ppm,labels-ah/im0082.ah.ppm
+stare-images/im0139.ppm,labels-ah/im0139.ah.ppm"""
+    ),
+    "test": io.StringIO(
+        """\
+stare-images/im0162.ppm,labels-ah/im0162.ah.ppm
+stare-images/im0163.ppm,labels-ah/im0163.ah.ppm
+stare-images/im0235.ppm,labels-ah/im0235.ah.ppm
+stare-images/im0236.ppm,labels-ah/im0236.ah.ppm
+stare-images/im0239.ppm,labels-ah/im0239.ah.ppm
+stare-images/im0240.ppm,labels-ah/im0240.ah.ppm
+stare-images/im0255.ppm,labels-ah/im0255.ah.ppm
+stare-images/im0291.ppm,labels-ah/im0291.ah.ppm
+stare-images/im0319.ppm,labels-ah/im0319.ah.ppm
+stare-images/im0324.ppm,labels-ah/im0324.ah.ppm"""
+    ),
+}
+
+
+@rc_variable_set("bob.ip.binseg.stare.datadir")
+def test_compare_to_json():
+
+    if TESTDB_TMPDIR is not None:
+        stare_dir = TESTDB_TMPDIR.name
+    else:
+        import bob.extension
+
+        stare_dir = bob.extension.rc.get("bob.ip.binseg.stare.datadir")
+
+    test_dataset = CSVDataset(
+        default,
+        stare._fieldnames,
+        stare._make_loader(stare_dir),
+        stare.data_path_keymaker,
+    )
+
+    for subset in ("train", "test"):
+        for t1, t2 in zip(
+            test_dataset.samples(subset),
+            json_dataset.subsets("default")[subset],
+        ):
+            nose.tools.eq_(t1.key, t2.key)
+            nose.tools.eq_(t1.data, t2.data)
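For reference, a minimal usage sketch of the in-memory subset support this patch introduces. It is not part of the patch: only the `CSVDataset` constructor and `samples()` method come from the diff above; the `loader` and `keymaker` callables and the `["data", "label"]` field names are hypothetical stand-ins for the dataset-specific ones (e.g. those provided by `bob.ip.binseg.data.stare`).

#!/usr/bin/env python
# Usage sketch (not part of the patch): builds a CSVDataset from CSV data
# held entirely in memory, so no files need to exist on disk.

import io

from bob.ip.binseg.data.dataset import CSVDataset

# a two-column subset kept in an in-memory text stream
subsets = {
    "train": io.StringIO(
        "stare-images/im0001.ppm,labels-ah/im0001.ah.ppm\n"
        "stare-images/im0002.ppm,labels-ah/im0002.ah.ppm\n"
    ),
}

def loader(context, item):
    # hypothetical: a real loader would open the files named in ``item``
    return item

def keymaker(context, item):
    # hypothetical: derive the sample key from the first CSV column
    return item["data"]

dataset = CSVDataset(subsets, ["data", "label"], loader, keymaker)

# samples() rewinds the stream (fileobj.seek(0)) after reading, so the
# same io.StringIO object can be consumed on every call
for sample in dataset.samples("train"):
    print(sample.key)  # -> "stare-images/im0001.ppm", then im0002

The `fileobj.seek(0)` calls added in both `JSONDataset` and `CSVDataset` are what make repeated reads of the same stream possible, which is also what lets the test above reuse the module-level `default` dictionary.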