Commit 33068f84 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

[py] Added sparse scoring

parent f99174fd
Pipeline #46236 passed with stage
in 10 minutes and 28 seconds
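The new `is_sparse` flag tells `CSVDatasetDevEval` that each probe row also names the reference it should be scored against (a `COMPARE_REFERENCE_ID` column in the csv protocols, or the legacy `for_scores.lst` format); probe rows sharing the same key are merged and accumulate a `references` list. A minimal usage sketch, with placeholder dataset and protocol names and the default sample loader:

from bob.bio.base.database import CSVDatasetDevEval

# "my_dataset" and "my_sparse_protocol" are placeholders for illustration
dataset = CSVDatasetDevEval("my_dataset", "my_sparse_protocol", is_sparse=True)
for probe in dataset.probes():
    # each probe SampleSet carries only the reference ids listed for it
    print(len(probe.references))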
......@@ -3,7 +3,7 @@ from .csv_dataset import (
CSVToSampleLoader,
CSVDatasetCrossValidation,
CSVBaseSampleLoader,
IdiapAnnotationsLoader,
AnnotationsLoader,
LSTToSampleLoader,
)
from .file import BioFile
......
......@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
#####
# ANNOTATIONS LOADERS
####
class IdiapAnnotationsLoader:
class AnnotationsLoader:
"""
Load annotations in the Idiap format
"""
......@@ -30,7 +30,7 @@ class IdiapAnnotationsLoader:
def __init__(
self,
annotation_directory=None,
annotation_extension=".pos",
annotation_extension=".json",
annotation_type="eyecenter",
):
self.annotation_directory = annotation_directory
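For reference, a minimal construction of the renamed loader (the annotation directory path is a placeholder; the tests below keep the legacy `.pos` extension explicitly):

loader = AnnotationsLoader(
    annotation_directory="/path/to/annotations",  # placeholder
    annotation_extension=".json",  # new default; pass ".pos" for legacy files
    annotation_type="eyecenter",
)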
......@@ -120,15 +120,21 @@ class CSVBaseSampleLoader(metaclass=ABCMeta):
sample_sets = dict()
for s in samples:
if s.reference_id not in sample_sets:
sample_sets[s.reference_id] = SampleSet(
[s], parent=s, references=references
sample_sets[s.reference_id] = (
SampleSet([s], parent=s)
if references is None
else SampleSet([s], parent=s, references=references)
)
else:
sample_sets[s.reference_id].append(s)
return list(sample_sets.values())
else:
return [SampleSet([s], parent=s, references=references) for s in samples]
return (
[SampleSet([s], parent=s) for s in samples]
if references is None
else [SampleSet([s], parent=s, references=references) for s in samples]
)
class CSVToSampleLoader(CSVBaseSampleLoader):
......@@ -163,7 +169,7 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
path = row[0]
reference_id = row[1]
kwargs = dict([[h, r] for h, r in zip(header[2:], row[2:])])
kwargs = dict([[str(h).lower(), r] for h, r in zip(header[2:], row[2:])])
if self.metadata_loader is not None:
metadata = self.metadata_loader(row)
......@@ -190,16 +196,28 @@ class LSTToSampleLoader(CSVBaseSampleLoader):
with open(filename) as cf:
reader = csv.reader(cf, delimiter=" ")
return [self.convert_row_to_sample(row) for row in reader]
samples = []
for row in reader:
if row[0][0] == "#":
continue
samples.append(self.convert_row_to_sample(row))
return samples
def convert_row_to_sample(self, row, header=None):
path = row[0]
reference_id = str(row[1])
kwargs = dict()
if len(row) == 3:
subject = row[2]
kwargs = {"subject": str(subject)}
if len(row) == 4:
path = row[0]
compare_reference_id = row[1]
reference_id = str(row[3])
kwargs = {"compare_reference_id": str(compare_reference_id)}
else:
path = row[0]
reference_id = str(row[1])
kwargs = dict()
if len(row) == 3:
subject = row[2]
kwargs = {"subject": str(subject)}
if self.metadata_loader is not None:
metadata = self.metadata_loader(row)
......@@ -232,12 +250,11 @@ class CSVDatasetDevEval(Database):
.. code-block:: text
my_dataset/
my_dataset/my_protocol/
my_dataset/my_protocol/train.csv
my_dataset/my_protocol/train.csv/dev_enroll.csv
my_dataset/my_protocol/train.csv/dev_probe.csv
my_dataset/my_protocol/train.csv/eval_enroll.csv
my_dataset/my_protocol/train.csv/eval_probe.csv
my_dataset/my_protocol/norm/train_world.csv
my_dataset/my_protocol/dev/for_models.csv
my_dataset/my_protocol/dev/for_probes.csv
my_dataset/my_protocol/eval/for_models.csv
my_dataset/my_protocol/eval/for_probes.csv
...
......@@ -245,8 +262,8 @@ class CSVDatasetDevEval(Database):
evaluation protocols this dataset might have.
The `my_protocol` directory should contain at least two csv files:
- dev_enroll.csv
- dev_probe.csv
- for_models.csv
- for_probes.csv
Each row of those csv files should contain: i) the path to the raw data and ii) the reference_id label
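A minimal sketch of loading such a protocol with the default `CSVToSampleLoader` (directory names are placeholders; `dev` is assumed to be the default group):

from bob.bio.base.database import CSVDatasetDevEval

dataset = CSVDatasetDevEval("my_dataset", "my_protocol")
train = dataset.background_model_samples()  # norm/train_world.csv
references = dataset.references()           # dev/for_models.csv -> SampleSets
probes = dataset.probes()                   # dev/for_probes.csv -> SampleSets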
......@@ -306,8 +323,10 @@ class CSVDatasetDevEval(Database):
dataset_original_directory="",
extension="",
),
is_sparse=False,
):
self.dataset_protocol_path = dataset_protocol_path
self.is_sparse = is_sparse
def get_paths():
......@@ -333,8 +352,9 @@ class CSVDatasetDevEval(Database):
os.path.join(protocol_path, "dev", "for_models.csv"),
)
legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
dev_probe_csv = path_discovery(
os.path.join(protocol_path, "dev", "for_probes.lst"),
os.path.join(protocol_path, "dev", legacy_probe),
os.path.join(protocol_path, "dev", "for_probes.csv"),
)
......@@ -344,7 +364,7 @@ class CSVDatasetDevEval(Database):
)
eval_probe_csv = path_discovery(
os.path.join(protocol_path, "eval", "for_probes.lst"),
os.path.join(protocol_path, "eval", legacy_probe),
os.path.join(protocol_path, "eval", "for_probes.csv"),
)
......@@ -421,16 +441,35 @@ class CSVDatasetDevEval(Database):
if self.cache[cache_label] is not None:
return self.cache[cache_label]
# Getting samples from CSV
samples = self.csv_to_sample_loader(self.__dict__[cache_label])
references = None
if fetching_probes:
references = list(
set([s.reference_id for s in self.references(group=group)])
)
if fetching_probes and self.is_sparse:
samples = self.csv_to_sample_loader(self.__dict__[cache_label])
# Checking if `is_sparse` was set properly
if len(samples) > 0 and not hasattr(samples[0], "compare_reference_id"):
raise ValueError(
f"Attribute `compare_reference_id` not found in `{samples[0]}`. "
"Make sure this attribute exists in your dataset if `is_sparse=True`"
)
# Group probe rows by sample key; each probe accumulates the
# reference ids it must be compared against
sparse_samples = dict()
for s in samples:
if s.key in sparse_samples:
sparse_samples[s.key].references.append(s.compare_reference_id)
else:
s.references = [s.compare_reference_id]
sparse_samples[s.key] = s
samples = sparse_samples.values()
else:
if fetching_probes:
references = list(
set([s.reference_id for s in self.references(group=group)])
)
sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
samples, group_by_reference_id=group_by_reference_id, references=references
samples, group_by_reference_id=group_by_reference_id, references=references,
)
self.cache[cache_label] = sample_sets
......
PATH,REFERENCE_ID
data/model3_session1_sample1,3
data/model3_session1_sample2,3
data/model3_session1_sample3,3
data/model3_session2_sample1,3
data/model4_session1_sample1,4
data/model4_session1_sample2,4
data/model4_session1_sample3,4
data/model4_session2_sample1,4
\ No newline at end of file
PATH,REFERENCE_ID,COMPARE_REFERENCE_ID
data/model3_session3_sample1,3,3
data/model3_session3_sample2,3,3
data/model3_session3_sample3,3,3
data/model3_session4_sample1,3,3
data/model4_session3_sample1,4,3
data/model4_session3_sample2,4,3
data/model4_session3_sample3,4,3
data/model4_session4_sample1,4,3
\ No newline at end of file
PATH,REFERENCE_ID
data/model3_session1_sample1,5
data/model3_session1_sample2,5
data/model3_session1_sample3,6
data/model3_session2_sample1,6
data/model4_session1_sample1,7
data/model4_session1_sample2,7
data/model4_session1_sample3,7
data/model4_session2_sample1,8
data/model4_session2_sample1,9
data/model4_session2_sample1,10
\ No newline at end of file
PATH,REFERENCE_ID,COMPARE_REFERENCE_ID
data/model3_session1_sample1,5,5
data/model3_session1_sample2,5,5
data/model3_session1_sample3,6,5
data/model3_session2_sample1,6,5
data/model4_session1_sample1,7,5
data/model4_session1_sample2,7,5
data/model4_session1_sample3,7,5
data/model4_session2_sample1,8,5
data/model4_session2_sample2,9,5
data/model4_session2_sample3,10,5
data/model4_session2_sample4,12,5
data/model4_session2_sample5,13,5
data/model4_session2_sample6,10,5
data/model3_session1_sample1,5,6
\ No newline at end of file
PATH,REFERENCE_ID
data/model11_session1_sample1,1
data/model11_session1_sample2,1
data/model11_session1_sample3,1
data/model11_session2_sample1,1
data/model12_session1_sample1,2
data/model12_session1_sample2,2
data/model12_session1_sample3,2
data/model12_session2_sample1,2
\ No newline at end of file
......@@ -5,4 +5,4 @@ data/model5_session4_sample1 5 5 5
data/model6_session3_sample1 6 6 6
data/model6_session3_sample2 6 6 6
data/model6_session3_sample3 6 6 6
data/model6_session4_sample1 6 6 6
data/model6_session4_sample1 6 6 6
\ No newline at end of file
......@@ -11,7 +11,7 @@ from bob.bio.base.database import (
CSVDatasetDevEval,
CSVToSampleLoader,
CSVDatasetCrossValidation,
IdiapAnnotationsLoader,
AnnotationsLoader,
LSTToSampleLoader,
)
import nose.tools
......@@ -32,6 +32,11 @@ legacy_example_dir = os.path.realpath(
bob.io.base.test_utils.datafile(".", __name__, "data/example_filelist")
)
legacy2_example_dir = os.path.realpath(
bob.io.base.test_utils.datafile(".", __name__, "data/example_filelist2")
)
example_dir = os.path.realpath(
bob.io.base.test_utils.datafile(".", __name__, "data/example_csv_filelist")
)
......@@ -71,23 +76,24 @@ def test_csv_file_list_dev_only_metadata():
dataset = CSVDatasetDevEval(example_dir, "protocol_only_dev_metadata")
assert len(dataset.background_model_samples()) == 8
assert check_all_true(dataset.background_model_samples(), DelayedSample)
assert np.alltrue(
["METADATA_1" in s.__dict__ for s in dataset.background_model_samples()]
["metadata_1" in s.__dict__ for s in dataset.background_model_samples()]
)
assert np.alltrue(
["METADATA_2" in s.__dict__ for s in dataset.background_model_samples()]
["metadata_2" in s.__dict__ for s in dataset.background_model_samples()]
)
assert len(dataset.references()) == 2
assert check_all_true(dataset.references(), SampleSet)
assert np.alltrue(["METADATA_1" in s.__dict__ for s in dataset.references()])
assert np.alltrue(["METADATA_2" in s.__dict__ for s in dataset.references()])
assert np.alltrue(["metadata_1" in s.__dict__ for s in dataset.references()])
assert np.alltrue(["metadata_2" in s.__dict__ for s in dataset.references()])
assert len(dataset.probes()) == 10
assert check_all_true(dataset.probes(), SampleSet)
assert np.alltrue(["METADATA_1" in s.__dict__ for s in dataset.probes()])
assert np.alltrue(["METADATA_2" in s.__dict__ for s in dataset.probes()])
assert np.alltrue(["metadata_1" in s.__dict__ for s in dataset.probes()])
assert np.alltrue(["metadata_2" in s.__dict__ for s in dataset.probes()])
assert np.alltrue(["references" in s.__dict__ for s in dataset.probes()])
......@@ -104,8 +110,8 @@ def test_csv_file_list_dev_eval():
"protocol_dev_eval",
csv_to_sample_loader=CSVToSampleLoader(
data_loader=bob.io.base.load,
metadata_loader=IdiapAnnotationsLoader(
annotation_directory=annotation_directory
metadata_loader=AnnotationsLoader(
annotation_directory=annotation_directory, annotation_extension=".pos"
),
dataset_original_directory="",
extension="",
......@@ -139,6 +145,68 @@ def test_csv_file_list_dev_eval():
assert len(dataset.groups()) == 3
def test_csv_file_list_dev_eval_sparse():
annotation_directory = os.path.realpath(
bob.io.base.test_utils.datafile(
".", __name__, "data/example_csv_filelist/annotations"
)
)
dataset = CSVDatasetDevEval(
example_dir,
"protocol_dev_eval_sparse",
csv_to_sample_loader=CSVToSampleLoader(
data_loader=bob.io.base.load,
metadata_loader=AnnotationsLoader(
annotation_directory=annotation_directory, annotation_extension=".pos"
),
dataset_original_directory="",
extension="",
),
is_sparse=True,
)
assert len(dataset.background_model_samples()) == 8
assert check_all_true(dataset.background_model_samples(), DelayedSample)
assert len(dataset.references()) == 2
assert check_all_true(dataset.references(), SampleSet)
probes = dataset.probes()
assert len(probes) == 8
# here, 1 comparison per probe
for p in probes:
assert len(p.references) == 1
assert check_all_true(dataset.references(), SampleSet)
assert len(dataset.references(group="eval")) == 6
assert check_all_true(dataset.references(group="eval"), SampleSet)
probes = dataset.probes(group="eval")
assert len(probes) == 13
assert check_all_true(probes, SampleSet)
# Here, 1 comparison per probe, EXCEPT THE FIRST ONE
for i, p in enumerate(probes):
if i == 0:
assert len(p.references) == 2
else:
assert len(p.references) == 1
assert len(dataset.all_samples(groups=None)) == 48
assert check_all_true(dataset.all_samples(groups=None), DelayedSample)
# Check the annotations
for s in dataset.all_samples(groups=None):
assert isinstance(s.annotations, dict)
assert len(dataset.reference_ids(group="dev")) == 2
assert len(dataset.reference_ids(group="eval")) == 6
assert len(dataset.groups()) == 3
def test_lst_file_list_dev_eval():
dataset = CSVDatasetDevEval(
......@@ -174,6 +242,60 @@ def test_lst_file_list_dev_eval():
assert len(dataset.groups()) == 3
def test_lst_file_list_dev_eval_sparse():
dataset = CSVDatasetDevEval(
legacy_example_dir,
"",
csv_to_sample_loader=LSTToSampleLoader(
data_loader=bob.io.base.load, dataset_original_directory="", extension="",
),
is_sparse=True,
)
assert len(dataset.background_model_samples()) == 8
assert check_all_true(dataset.background_model_samples(), DelayedSample)
assert len(dataset.references()) == 2
assert check_all_true(dataset.references(), SampleSet)
assert len(dataset.probes()) == 8
assert check_all_true(dataset.references(), SampleSet)
assert len(dataset.references(group="eval")) == 2
assert check_all_true(dataset.references(group="eval"), SampleSet)
assert len(dataset.probes(group="eval")) == 8
assert check_all_true(dataset.probes(group="eval"), SampleSet)
assert len(dataset.all_samples(groups=None)) == 44
assert check_all_true(dataset.all_samples(groups=None), DelayedSample)
assert len(dataset.reference_ids(group="dev")) == 2
assert len(dataset.reference_ids(group="eval")) == 2
assert len(dataset.groups()) == 3
def test_lst_file_list_dev_sparse_filelist2():
dataset = CSVDatasetDevEval(
legacy2_example_dir,
"",
csv_to_sample_loader=LSTToSampleLoader(
data_loader=bob.io.base.load, dataset_original_directory="", extension="",
),
is_sparse=True,
)
assert len(dataset.references()) == 3
assert check_all_true(dataset.references(), SampleSet)
assert len(dataset.probes()) == 9
assert check_all_true(dataset.references(), SampleSet)
def test_csv_file_list_atnt():
dataset = CSVDatasetDevEval(atnt_protocol_path, "idiap_protocol")
......
......@@ -3,6 +3,7 @@
# Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
# Roy Wallace <roy.wallace@idiap.ch>
from .resources import *
from .io import *
import six
......
......@@ -5,61 +5,65 @@ import collections # this is needed for the sphinx documentation
import functools # this is needed for the sphinx documentation
import numpy
import logging
logger = logging.getLogger("bob.bio.base")
from .. import database
from bob.bio.base import database
import bob.io.base
def filter_missing_files(file_names, split_by_client=False, allow_missing_files=True):
"""This function filters out files that do not exist, but only if ``allow_missing_files`` is set to ``True``, otherwise the list of ``file_names`` is returned unaltered."""
if not allow_missing_files:
return file_names
if split_by_client:
# filter out missing files and empty clients
existing_files = [
[f for f in client_files if os.path.exists(f)] for client_files in file_names]
existing_files = [
client_files for client_files in existing_files if client_files]
else:
# filter out missing files
existing_files = [f for f in file_names if os.path.exists(f)]
return existing_files
"""This function filters out files that do not exist, but only if ``allow_missing_files`` is set to ``True``, otherwise the list of ``file_names`` is returned unaltered."""
if not allow_missing_files:
return file_names
if split_by_client:
# filter out missing files and empty clients
existing_files = [
[f for f in client_files if os.path.exists(f)]
for client_files in file_names
]
existing_files = [
client_files for client_files in existing_files if client_files
]
else:
# filter out missing files
existing_files = [f for f in file_names if os.path.exists(f)]
return existing_files
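# Hedged usage sketch for filter_missing_files (paths are placeholders): with
# allow_missing_files=True, paths that do not exist on disk are dropped.
names = ["/data/client1/sample1.hdf5", "/data/client1/missing.hdf5"]
existing = filter_missing_files(names, split_by_client=False, allow_missing_files=True)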
def filter_none(data, split_by_client=False):
"""This function filters out ``None`` values from the given list (or list of lists, when ``split_by_client`` is enabled)."""
if split_by_client:
# filter out missing files and empty clients
existing_data = [[d for d in client_data if d is not None]
for client_data in data]
existing_data = [
client_data for client_data in existing_data if client_data]
else:
# filter out missing files
existing_data = [d for d in data if d is not None]
return existing_data
"""This function filters out ``None`` values from the given list (or list of lists, when ``split_by_client`` is enabled)."""
if split_by_client:
# filter out missing files and empty clients
existing_data = [
[d for d in client_data if d is not None] for client_data in data
]
existing_data = [client_data for client_data in existing_data if client_data]
else:
# filter out missing files
existing_data = [d for d in data if d is not None]
return existing_data
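# Hedged usage sketch for filter_none: None entries are dropped and, with
# split_by_client=True, clients left empty are removed as well.
per_client = [[1.0, None, 2.0], [None]]
cleaned = filter_none(per_client, split_by_client=True)  # -> [[1.0, 2.0]]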
def check_file(filename, force, expected_file_size=1):
"""Checks if the file with the given ``filename`` exists and has size greater or equal to ``expected_file_size``.
"""Checks if the file with the given ``filename`` exists and has size greater or equal to ``expected_file_size``.
If the file is too small, **or** if the ``force`` option is set to ``True``, the file is removed.
This function returns ``True`` if the file exists (and has not been removed), otherwise ``False``"""
if os.path.exists(filename):
if force or os.path.getsize(filename) < expected_file_size:
logger.debug(" .. Removing old file '%s'.", filename)
os.remove(filename)
return False
else:
return True
return False
if os.path.exists(filename):
if force or os.path.getsize(filename) < expected_file_size:
logger.debug(" .. Removing old file '%s'.", filename)
os.remove(filename)
return False
else:
return True
return False
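# Hedged usage sketch for check_file (the file name is a placeholder): recompute a
# result only when the existing file is missing, too small, or force is requested.
if not check_file("scores.hdf5", force=False, expected_file_size=1):
    recompute_and_save_scores()  # hypothetical helper, not part of this module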
def read_original_data(biofile, directory, extension):
"""This function reads the original data using the given ``biofile`` instance.
"""This function reads the original data using the given ``biofile`` instance.
It simply calls ``load(directory, extension)`` from :py:class:`bob.bio.base.database.BioFile` or one of its derivatives.
Parameters
......@@ -77,95 +81,98 @@ def read_original_data(biofile, directory, extension):
object
Whatever ``biofile.load`` returns; usually a :py:class:`numpy.ndarray`
"""
assert isinstance(biofile, database.BioFile)
return biofile.load(directory, extension)
assert isinstance(biofile, database.BioFile)
return biofile.load(directory, extension)
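# Hedged usage sketch for read_original_data: `my_biofile` stands for an existing
# bob.bio.base.database.BioFile instance; the directory and extension are placeholders.
data = read_original_data(my_biofile, "/path/to/database", ".png")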
def load(file):
"""Loads data from file. The given file might be an HDF5 file open for reading or a string."""
if isinstance(file, bob.io.base.HDF5File):
return file.read("array")
else:
return bob.io.base.load(file)
"""Loads data from file. The given file might be an HDF5 file open for reading or a string."""
if isinstance(file, bob.io.base.HDF5File):
return file.read("array")
else:
return bob.io.base.load(file)
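# Hedged usage sketch: load() accepts either an open bob.io.base.HDF5File or a
# filename string; "features.hdf5" is a placeholder.
array = load("features.hdf5")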
def save(data, file, compression=0):
"""Saves the data to file using HDF5. The given file might be an HDF5 file open for writing, or a string.
"""Saves the data to file using HDF5. The given file might be an HDF5 file open for writing, or a string.
If the given data contains a ``save`` method, this method is called with the given HDF5 file.
Otherwise the data is written to the HDF5 file using the given compression."""
f = file if isinstance(file, bob.io.base.HDF5File) else bob.io.base.HDF5File(file, 'w')
if hasattr(data, 'save'):
data.save(f)
else:
f.set("array", data, compression=compression)
f = (
file
if isinstance(file, bob.io.base.HDF5File