Commit 64bc0eb0 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Remove fiddling with bob.pipelines.Samples's internals

parent 6a60ffd4
......@@ -44,7 +44,6 @@ class CSVBaseSampleLoader(metaclass=ABCMeta):
self.data_loader = data_loader
self.extension = extension
self.dataset_original_directory = dataset_original_directory
self.excluding_attributes = ["_data", "load", "key"]
@abstractmethod
def __call__(self, filename):
......@@ -104,15 +103,6 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
def convert_samples_to_samplesets(
self, samples, group_by_subject=True, references=None
):
def get_attribute_from_sample(sample):
return dict(
[
[attribute, sample.__dict__[attribute]]
for attribute in list(sample.__dict__.keys())
if attribute not in self.excluding_attributes
]
)
if group_by_subject:
# Grouping sample sets
......@@ -120,7 +110,7 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
for s in samples:
if s.subject not in sample_sets:
sample_sets[s.subject] = SampleSet(
[s], **get_attribute_from_sample(s)
[s], parent=s, references=references
)
else:
sample_sets[s.subject].append(s)
......@@ -128,7 +118,7 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
else:
return [
SampleSet([s], **get_attribute_from_sample(s), references=references)
SampleSet([s], parent=s, references=references)
for s in samples
]
......@@ -174,7 +164,7 @@ class CSVDatasetDevEval:
path_i,subject_j
...
You might want to ship metadata within your Samples (e.g. gender, age, annotation, ...)
Doing so is simple; just follow the example below:
......@@ -189,7 +179,7 @@ class CSVDatasetDevEval:
The files `my_dataset/my_protocol/eval_enroll.csv` and `my_dataset/my_protocol/eval_probe.csv`
are optional and are used in case a protocol contains data for evaluation.
Finally, the content of the file `my_dataset/my_protocol/train.csv` is used in the case a protocol
contains data for training (`bob.bio.base.pipelines.vanilla_biometrics.Database.background_model_samples`)
......@@ -329,7 +319,7 @@ class CSVDatasetDevEval:
class CSVDatasetCrossValidation:
"""
Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
handles **CROSS VALIDATION**.
Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
......
......@@ -3,7 +3,7 @@
from abc import ABCMeta, abstractmethod
from bob.pipelines.sample import Sample, SampleSet, DelayedSample
from bob.pipelines.sample import SAMPLE_DATA_ATTRS, Sample, SampleSet, DelayedSample
import functools
import numpy as np
import os
......@@ -211,13 +211,7 @@ class BioAlgorithm(metaclass=ABCMeta):
scores_biometric_references.append(Sample(score, parent=ref))
# Fetching metadata from the probe
kwargs = dict(
(metadata, sampleset.__dict__[metadata])
for metadata in sampleset.__dict__.keys()
if metadata not in ["samples", "key", "data", "load", "_data"]
)
return SampleSet(scores_biometric_references, parent=sampleset, **kwargs)
return SampleSet(scores_biometric_references, parent=sampleset)
@abstractmethod
def score(self, biometric_reference, data):
......
......@@ -4,6 +4,7 @@
import os
from bob.pipelines import SampleSet, DelayedSample
from bob.pipelines.sample import SAMPLE_DATA_ATTRS
from .abstract_classes import ScoreWriter
import functools
import csv
......@@ -74,7 +75,7 @@ class CSVScoreWriter(ScoreWriter):
self,
path,
n_sample_sets=1000,
exclude_list=["samples", "key", "data", "load", "_data", "references", "annotations"],
exclude_list=tuple(SAMPLE_DATA_ATTRS) + ("key", "references", "annotations"),
):
super().__init__(path)
self.n_sample_sets = n_sample_sets
......@@ -92,13 +93,13 @@ class CSVScoreWriter(ScoreWriter):
probe_dict = dict(
(k, f"probe_{k}")
for k in probe_sampleset.__dict__.keys()
if k not in self.exclude_list
if not (k in self.exclude_list or k.startswith("__"))
)
bioref_dict = dict(
(k, f"bio_ref_{k}")
for k in first_biometric_reference.__dict__.keys()
if k not in self.exclude_list
if not (k in self.exclude_list or k.startswith("__"))
)
header = (
......@@ -130,7 +131,7 @@ class CSVScoreWriter(ScoreWriter):
rows = []
probe_row = [str(probe.key)] + [
str(probe.__dict__[k]) for k in probe_dict.keys()
str(getattr(probe, k)) for k in probe_dict.keys()
]
# If it's delayed, load it
......@@ -139,7 +140,7 @@ class CSVScoreWriter(ScoreWriter):
for biometric_reference in probe:
bio_ref_row = [
str(biometric_reference.__dict__[k])
str(getattr(biometric_reference, k))
for k in list(bioref_dict.keys()) + ["data"]
]
......@@ -182,4 +183,4 @@ class CSVScoreWriter(ScoreWriter):
if isinstance(score_paths, dask.bag.Bag):
all_paths = dask.delayed(list)(score_paths)
return dask.delayed(_post_process)(all_paths, path)
return _post_process(score_paths, path)
return _post_process(score_paths, path)
......@@ -288,7 +288,7 @@ def test_norm_mechanics():
z_normed_scores = _dump_scores_from_samples(
z_normed_score_samples, shape=(n_probes, n_references)
)
assert np.allclose(z_normed_scores, z_normed_scores_ref)
np.testing.assert_allclose(z_normed_scores, z_normed_scores_ref)
############
# TESTING T-NORM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment