Commit 666b29c7 authored by Yannick DAYER

Merge branch 'fix-sset-metadata' into 'master'

fix [CSVDatabase]: defining templates metadata.

Closes #191

See merge request !320
parents e6cfce35 9a1777d1
Pipeline #72156 passed
@@ -195,6 +195,7 @@ class CSVDatabase(FileListDatabase, Database):
        protocol: str,
        dataset_protocols_path: Optional[str] = None,
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        templates_metadata: Optional[list[str]] = None,
        annotation_type: Optional[str] = None,
        fixed_positions: Optional[dict[str, tuple[float, float]]] = None,
        memory_demanding=False,
@@ -213,6 +214,10 @@ class CSVDatabase(FileListDatabase, Database):
        transformer
            An sklearn pipeline or equivalent transformer that handles some light
            preprocessing of the samples (This will always run locally).
        templates_metadata
            Metadata that originates from the samples and must be present in the
            templates (``SampleSet``), e.g. ``["gender", "age"]``. This should be
            metadata that is common to all the samples in a template.
        annotation_type
            A string describing the annotations passed to the annotation loading
            function
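As a usage sketch (not part of the diff; the dataset name, protocol name, and path below are hypothetical), the new argument is simply passed at construction time. Every listed column must exist on the samples and be constant within a template:

from bob.bio.base.database import CSVDatabase

# Hypothetical dataset/protocol names: templates_metadata lists CSV columns
# that are shared by all samples of a template and should be copied onto the
# resulting SampleSet objects.
database = CSVDatabase(
    name="my_dataset",
    protocol="my_protocol",
    dataset_protocols_path="/path/to/protocol/definitions",
    templates_metadata=["gender", "age"],
)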
@@ -243,6 +248,10 @@ class CSVDatabase(FileListDatabase, Database):
        else:
            self.score_all_vs_all = False

        self.templates_metadata = []
        if templates_metadata is not None:
            self.templates_metadata = templates_metadata

    def list_file(self, group: str, name: str) -> TextIO:
        """Returns a definition file containing one sample per row.
@@ -307,11 +316,16 @@ class CSVDatabase(FileListDatabase, Database):
            # we add that as well.
            samples = list(samples_for_template_id)
            subject_id = samples[0].subject_id
            metadata = {
                m: getattr(samples[0], m) for m in self.templates_metadata
            }
            sample_sets.append(
                SampleSet(
                    samples,
                    template_id=template_id,
                    subject_id=subject_id,
                    key=f"template_{template_id}",
                    **metadata,
                )
            )
        validate_bio_samples(sample_sets)
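A standalone sketch of the propagation logic above, using bob.pipelines directly (illustrative only, the sample values are made up): the metadata listed in templates_metadata is read from the first sample of each template and forwarded to the SampleSet as extra attributes.

from bob.pipelines import Sample, SampleSet

templates_metadata = ["gender", "age"]
samples = [
    Sample(data=None, key="s1", subject_id="3", gender="F", age=10),
    Sample(data=None, key="s2", subject_id="3", gender="F", age=10),
]
# Same pattern as in the diff: read the values from the first sample...
metadata = {m: getattr(samples[0], m) for m in templates_metadata}
# ...and attach them to the template-level SampleSet as attributes.
sample_set = SampleSet(
    samples,
    template_id="3",
    subject_id=samples[0].subject_id,
    key="template_3",
    **metadata,
)
assert sample_set.gender == "F" and sample_set.age == 10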
path,subject_id,template_id,metadata_1,metadata_2
data/model3_session1_sample1,3,3,F,10
data/model3_session1_sample2,3,3,F,10
data/model3_session1_sample3,3,3,F,10
data/model3_session2_sample1,3,3,F,10
data/model4_session1_sample1,4,4,M,30
data/model4_session1_sample2,4,4,M,30
data/model4_session1_sample3,4,4,M,30
data/model4_session2_sample1,4,4,M,30
path,subject_id,template_id,metadata_1,metadata_2,sample_metadata,subject_metadata
data/model3_session1_sample1,3,3,F,10,1,A
data/model3_session1_sample2,3,3,F,10,2,A
data/model3_session1_sample3,3,3,F,10,3,A
data/model3_session2_sample1,3,3,F,10,4,A
data/model4_session1_sample1,4,4,M,30,5,B
data/model4_session1_sample2,4,4,M,30,6,B
data/model4_session1_sample3,4,4,M,30,7,B
data/model4_session2_sample1,4,4,M,30,8,B
path,subject_id,template_id,metadata_1,metadata_2
data/model3_session3_sample1,3,3,F,10
data/model3_session3_sample2,3,3,F,10
data/model3_session3_sample3,3,3,F,10
data/model3_session4_sample1,3,3,F,10
data/model4_session3_sample1,4,4,M,30
data/model4_session3_sample2,4,4,M,30
data/model4_session3_sample1,4,4,M,30
data/model4_session3_sample2,4,4,M,30
data/model4_session3_sample3,4,4,M,30
data/model4_session4_sample1,4,4,M,30
path,subject_id,template_id,metadata_1,metadata_2,sample_metadata,subject_metadata
data/model3_session3_sample1,3,3,F,10,1,A
data/model3_session3_sample2,3,3,F,10,2,A
data/model3_session3_sample3,3,3,F,10,3,A
data/model3_session4_sample1,3,3,F,10,4,A
data/model4_session3_sample1,4,4,M,30,5,B
data/model4_session3_sample2,4,4,M,30,6,B
data/model4_session3_sample1,4,4,M,30,7,B
data/model4_session3_sample2,4,4,M,30,8,B
data/model4_session3_sample3,4,4,M,30,9,B
data/model4_session4_sample1,4,4,M,30,10,B
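The second pair of definition files above (the ones with the extra sample_metadata and subject_metadata columns) backs the new protocol_only_dev_metadata protocol: every column other than path becomes an attribute of the corresponding sample, but only columns that are constant within a template, such as subject_metadata here, are sensible entries for templates_metadata (sample_metadata changes from row to row). A small standard-library sketch of that distinction, using two rows copied from the file above:

import csv
import io

rows = list(csv.DictReader(io.StringIO(
    "path,subject_id,template_id,metadata_1,metadata_2,sample_metadata,subject_metadata\n"
    "data/model3_session1_sample1,3,3,F,10,1,A\n"
    "data/model3_session1_sample2,3,3,F,10,2,A\n"
)))
# Constant within template 3 -> suitable for templates_metadata.
assert len({r["subject_metadata"] for r in rows}) == 1
# Varies per sample -> stays a per-sample attribute only.
assert len({r["sample_metadata"] for r in rows}) == 2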
@@ -8,10 +8,15 @@
Very simple tests for Implementations
"""
from pathlib import Path
import bob.bio.base
from bob.bio.base.config.dummy.database import database as dummy_database
from bob.pipelines import DelayedSample
from bob.bio.base.database import CSVDatabase
from bob.pipelines import DelayedSample, SampleSet

DATA_DIR = Path(__file__).parent / "data"


def test_all_samples():
@@ -27,7 +32,41 @@ def test_atnt():
    database = bob.bio.base.load_resource(
        "atnt", "database", preferred_package="bob.bio.base"
    )
    assert len(database.background_model_samples()) > 0
    assert len(database.references()) > 0
    assert len(database.probes()) > 0
    assert len(database.all_samples()) > 0
    train_set = database.background_model_samples()
    assert len(train_set) > 0
    assert isinstance(train_set[0], DelayedSample)

    references = database.references()
    assert len(references) > 0
    references_sset = references[0]
    assert isinstance(references_sset, SampleSet)
    assert hasattr(references_sset, "key")
    assert hasattr(references_sset, "subject_id")
    assert hasattr(references_sset, "template_id")
    references_sample = references_sset.samples[0]
    assert isinstance(references_sample, DelayedSample)
    assert hasattr(references_sample, "key")

    probes = database.probes()
    assert len(probes) > 0
    assert isinstance(probes[0], SampleSet)
    assert isinstance(probes[0].samples[0], DelayedSample)

    all_samples = database.all_samples()
    assert len(all_samples) > 0
    assert isinstance(all_samples[0], DelayedSample)


def test_metadata():
    local_protocol_definition_path = DATA_DIR / "example_csv_filelist"
    database = CSVDatabase(
        name="dummy_metadata",
        protocol="protocol_only_dev_metadata",
        dataset_protocols_path=local_protocol_definition_path,
        templates_metadata=["subject_metadata"],
    )

    references_sset = database.references()[0]
    assert hasattr(references_sset, "subject_metadata")
    references_sample = references_sset.samples[0]
    assert hasattr(references_sample, "sample_metadata")

    probes_sset = database.probes()[0]
    assert hasattr(probes_sset, "subject_metadata")
    probes_sample = probes_sset.samples[0]
    assert hasattr(probes_sample, "sample_metadata")
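A possible follow-up (hypothetical, not part of this merge request): the dev definition files shown earlier imply concrete values, e.g. subject_metadata "A" for template 3, which could also be asserted. A sketch reusing the imports and DATA_DIR from the file above, and selecting the template by id rather than relying on ordering:

def test_metadata_values():
    # Hypothetical extension of test_metadata: check actual values, not just
    # the presence of the attributes; values come from the CSV rows above.
    database = CSVDatabase(
        name="dummy_metadata",
        protocol="protocol_only_dev_metadata",
        dataset_protocols_path=DATA_DIR / "example_csv_filelist",
        templates_metadata=["subject_metadata"],
    )
    sset_3 = next(
        s for s in database.references() if str(s.template_id) == "3"
    )
    assert sset_3.subject_metadata == "A"
    assert all(str(s.subject_id) == "3" for s in sset_3.samples)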