Skip to content
Snippets Groups Projects
Commit 40b42c5b authored by Tiago de Freitas Pereira
Browse files

Implemented a new protocol distribution

parent 2790cd2d
No related branches found
No related tags found
1 merge request: !4 "New protocol distribution"
Pipeline #40056 failed
...@@ -17,8 +17,9 @@ from bob.pipelines.sample import SampleSet, DelayedSample ...@@ -17,8 +17,9 @@ from bob.pipelines.sample import SampleSet, DelayedSample
from .protocol import Protocol from .protocol import Protocol
import bob.extension.log import logging
logger = bob.extension.log.setup("bob.db.morph") logger = logging.getLogger(__name__)
import copy
class Database: class Database:
...@@ -58,8 +59,6 @@ class Database: ...@@ -58,8 +59,6 @@ class Database:
self.extension = original_extension # NOT TAKEN INTO ACCOUNT self.extension = original_extension # NOT TAKEN INTO ACCOUNT
# extension is already present in metadata filenames # extension is already present in metadata filenames
# Request the correct protocol definition object
self.protocol = Protocol(protocol)
# Using a local copy of the morph_2008_nonCommercial.csv given with # Using a local copy of the morph_2008_nonCommercial.csv given with
# the morph dataset. # the morph dataset.
...@@ -81,11 +80,14 @@ class Database: ...@@ -81,11 +80,14 @@ class Database:
"photo", # File name of the sample "photo", # File name of the sample
]] ]]
logger.debug(f"Filtering protocol genders.") #logger.debug(f"Filtering protocol genders.")
self.dataframe = self.protocol.filter_gender(self.dataframe) #self.dataframe = self.protocol.filter_gender(self.dataframe)
#logger.debug(f"Filtering protocol ethnicities.")
#self.dataframe = self.protocol.filter_ethnicity(self.dataframe)
logger.debug(f"Filtering protocol ethnicities.") # Request the correct protocol definition object
self.dataframe = self.protocol.filter_ethnicity(self.dataframe) self.protocol = Protocol(protocol, self.dataframe)
# Using a local copy of the MORPH_Album2_EYECOORDS.csv given with morph # Using a local copy of the MORPH_Album2_EYECOORDS.csv given with morph
eyecoords_file = ( eyecoords_file = (
...@@ -124,8 +126,9 @@ class Database: ...@@ -124,8 +126,9 @@ class Database:
f"'{self.protocol.name}'." f"'{self.protocol.name}'."
) )
# Filter the dataframe to keep only 'world' subject (with Protocol) # Filter the dataframe to keep only 'world' subject (with Protocol)
world_dataframe = self.protocol.world_filter(self.dataframe) world_dataframe = self.protocol.world_filter()
# Convert the filtered dataframe to a list of SampleSet # Convert the filtered dataframe to a list of SampleSet
samplesets = self._create_list_of_samplesets(world_dataframe) samplesets = self._create_list_of_samplesets(world_dataframe)
...@@ -168,15 +171,9 @@ class Database: ...@@ -168,15 +171,9 @@ class Database:
) )
group = "dev" group = "dev"
if group == "dev":
# Filter the dataframe to keep only 'dev' subjects of protocol
set_dataframe = self.protocol.dev_filter(self.dataframe)
elif group == "eval":
# Filter the dataframe to keep only 'eval' subjects of protocol
set_dataframe = self.protocol.eval_filter(self.dataframe)
# Keep only the references samples # Keep only the references samples
refs_dataframe = self.protocol.references_filter(set_dataframe) refs_dataframe = self.protocol.references_filter(group=group)
# Convert the filtered dataframe to a list of SampleSet # Convert the filtered dataframe to a list of SampleSet
samplesets = self._create_list_of_samplesets(refs_dataframe) samplesets = self._create_list_of_samplesets(refs_dataframe)
...@@ -216,17 +213,10 @@ class Database: ...@@ -216,17 +213,10 @@ class Database:
) )
group = "dev" group = "dev"
if group == "dev":
# Filter the dataframe to keep only 'dev' subject (with Protocol)
set_dataframe = self.protocol.dev_filter(self.dataframe)
elif group == "eval":
# Filter the dataframe to keep only 'eval' subject (with Protocol)
set_dataframe = self.protocol.eval_filter(self.dataframe)
# Keep only the probes samples # Keep only the probes samples
probes_dataframe = self.protocol.probes_filter(set_dataframe) probes_dataframe = self.protocol.probes_filter(group)
references_id = self.protocol.references_list(probes_dataframe) references_id = self.protocol.references_list(group)
# Convert the filtered dataframe to a list of SampleSet # Convert the filtered dataframe to a list of SampleSet
samplesets = self._create_list_of_samplesets( samplesets = self._create_list_of_samplesets(
...@@ -260,7 +250,7 @@ class Database: ...@@ -260,7 +250,7 @@ class Database:
f"'{self.protocol.name}'." f"'{self.protocol.name}'."
) )
zprobes_dataframe = self.protocol.z_probes_filter(self.dataframe) zprobes_dataframe = self.protocol.z_probes_filter()
# Convert the filtered dataframe to a list of SampleSet # Convert the filtered dataframe to a list of SampleSet
zprobes = self._create_list_of_samplesets(zprobes_dataframe) zprobes = self._create_list_of_samplesets(zprobes_dataframe)
...@@ -291,7 +281,7 @@ class Database: ...@@ -291,7 +281,7 @@ class Database:
if covariate not in self.dataframe.columns: if covariate not in self.dataframe.columns:
raise ValueError(f"'{covariate}' not a column of metadata df.") raise ValueError(f"'{covariate}' not a column of metadata df.")
treferences_dataframe = self.protocol.t_references_filter(self.dataframe) treferences_dataframe = self.protocol.t_references_filter()
# Convert the filtered dataframe to a list of SampleSet # Convert the filtered dataframe to a list of SampleSet
treferences = self._create_list_of_samplesets( treferences = self._create_list_of_samplesets(
...@@ -403,7 +393,7 @@ class Database: ...@@ -403,7 +393,7 @@ class Database:
Each SampleSet object contains all the samples of one ID. Each SampleSet object contains all the samples of one ID.
""" """
sets = {} # Stores the resulting sequence of SampleSet (as dict now) sets = {} # Stores the resulting sequence of SampleSet (as dict now)
logger.debug(f" Creating SampleSets")
if covariate != None: if covariate != None:
covariate_col = list(frame.columns).index(covariate)+1 covariate_col = list(frame.columns).index(covariate)+1
...@@ -415,12 +405,12 @@ class Database: ...@@ -415,12 +405,12 @@ class Database:
folder, file = row.photo.split('/') folder, file = row.photo.split('/')
path = os.path.join(folder, file[:3], file) path = os.path.join(folder, file[:3], file)
if subject not in sets: if subject not in sets:
logger.debug(f" Creating SampleSet for subject '{subject}'.") #logger.debug(f" Creating SampleSet for subject '{subject}'.")
sets[subject] = SampleSet( sets[subject] = SampleSet(
samples=[], # Start with an empty one, fill it below samples=[], # Start with an empty one, fill it below
key=self._subject_to_key(subject), key=self._subject_to_key(subject),
path=path, path=path,
subject=subject, subject=str(subject),
date_of_birth=row.dob, date_of_birth=row.dob,
photo_date=row.doa, photo_date=row.doa,
age_phd=row.age, age_phd=row.age,
...@@ -433,11 +423,11 @@ class Database: ...@@ -433,11 +423,11 @@ class Database:
if self._subject_to_key(subject) not in references_ids: if self._subject_to_key(subject) not in references_ids:
references_ids = references_ids[:] references_ids = references_ids[:]
references_ids[0] = self._subject_to_key(subject) references_ids[0] = self._subject_to_key(subject)
sets[subject].references = references_ids sets[subject].references = copy.deepcopy(references_ids)
logger.debug( #logger.debug(
f" Adding Sample for subject '{subject}', image {row.photo}." # f" Adding Sample for subject '{subject}', image {row.photo}."
) #)
# Using SampleSet 'insert' method # Using SampleSet 'insert' method
sets[subject].insert( sets[subject].insert(
index=-1, # Insert at last position index=-1, # Insert at last position
...@@ -447,11 +437,10 @@ class Database: ...@@ -447,11 +437,10 @@ class Database:
os.path.join(self.directory, path), os.path.join(self.directory, path),
), ),
key=path, key=path,
subject=subject, subject=str(subject),
annotations=self._eyes_annotations(file.split('.')[0]), annotations=self._eyes_annotations(file.split('.')[0]),
) )
) )
if covariate != None: if covariate != None:
sets[subject].cohort = row[covariate_col] sets[subject].cohort = row[covariate_col]
return list(sets.values()) return list(sets.values())
This diff is collapsed.
...@@ -10,8 +10,8 @@ from .database import Database ...@@ -10,8 +10,8 @@ from .database import Database
from .protocol import Protocol from .protocol import Protocol
from bob.pipelines.sample import SampleSet, DelayedSample from bob.pipelines.sample import SampleSet, DelayedSample
import bob.extension.log import logging
logger = bob.extension.log.setup("bob.db.morph") logger = logging.getLogger(__name__)
morph_base_directory = "morph_base_dir" # No file access during tests morph_base_directory = "morph_base_dir" # No file access during tests
...@@ -90,7 +90,8 @@ def _check_valid_samplesets( ...@@ -90,7 +90,8 @@ def _check_valid_samplesets(
assert hasattr(sampleset, "key"), "SampleSet must have an ID." assert hasattr(sampleset, "key"), "SampleSet must have an ID."
assert hasattr(sampleset, "path"), "SampleSet must have a path." assert hasattr(sampleset, "path"), "SampleSet must have a path."
assert hasattr(sampleset, "subject"), "SampleSet must have a name." assert hasattr(sampleset, "subject"), "SampleSet must have a name."
assert sampleset.subject == sampleset.key, (
assert str(sampleset.subject) == str(sampleset.key), (
f"For morph, SampleSet subject ({sampleset.subject}) and key " f"For morph, SampleSet subject ({sampleset.subject}) and key "
f"({sampleset.key}) must be equal." f"({sampleset.key}) must be equal."
) )
...@@ -133,14 +134,19 @@ def _check_valid_samples(samples): ...@@ -133,14 +134,19 @@ def _check_valid_samples(samples):
assert isinstance(s, DelayedSample) assert isinstance(s, DelayedSample)
assert hasattr(s, "subject"), "Sample must have a subject." assert hasattr(s, "subject"), "Sample must have a subject."
assert hasattr(s, "key"), "Sample must have a key." assert hasattr(s, "key"), "Sample must have a key."
assert isinstance(s.subject, int), "subject must be an int." assert isinstance(s.subject, str), "subject must be an int."
assert isinstance(s.key, str), "Sample.key should be a string (path)." assert isinstance(s.key, str), "Sample.key should be a string (path)."
def _commons(a,b): def _commons(a,b):
""" Returns the list of common items in two iterables """ Returns the list of common items in two iterables
""" """
return [c for c in a if c in b] logger.info("Comparing")
a_set = set([i.subject for i in a])
b_set = set([i.subject for i in b])
return a_set.intersection(b_set)
############################################################################### ###############################################################################
...@@ -171,7 +177,7 @@ def test_verification_fold1_world(): ...@@ -171,7 +177,7 @@ def test_verification_fold1_world():
world_set = db.background_model_samples() world_set = db.background_model_samples()
assert len(world_set) == 457 assert len(world_set) == 222
_check_valid_samples(world_set) _check_valid_samples(world_set)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -185,7 +191,7 @@ def test_verification_fold1_references_dev(): ...@@ -185,7 +191,7 @@ def test_verification_fold1_references_dev():
_check_valid_samplesets( _check_valid_samplesets(
samplesets=references_sets_dev, samplesets=references_sets_dev,
expected_length=(13145-22)//2, # 50% of multiple-images subjects (minus 22 zt subjects) expected_length=6731,
expected_sample_count=1, expected_sample_count=1,
match_sample_count_exactly=True, match_sample_count_exactly=True,
) )
...@@ -196,11 +202,11 @@ def test_verification_fold1_references_dev(): ...@@ -196,11 +202,11 @@ def test_verification_fold1_references_dev():
def test_verification_fold1_probes_dev(): def test_verification_fold1_probes_dev():
db = Database("verification_fold1", morph_base_directory) db = Database("verification_fold1", morph_base_directory)
probes_set_dev = db.probes(group="dev") probes_set_dev = db.probes(group="dev")
_check_valid_samplesets( _check_valid_samplesets(
samplesets=probes_set_dev, samplesets=probes_set_dev,
expected_length=(13145-22)//2, # 50% of multiple-images subjects (minus 22 zt subjects) expected_length=6548,
expected_sample_count=1, expected_sample_count=1,
match_sample_count_exactly=False, match_sample_count_exactly=False,
is_probes=True, is_probes=True,
...@@ -217,7 +223,7 @@ def test_verification_fold1_references_eval(): ...@@ -217,7 +223,7 @@ def test_verification_fold1_references_eval():
_check_valid_samplesets( _check_valid_samplesets(
samplesets=references_sets_eval, samplesets=references_sets_eval,
expected_length=(13145-22)//2+1, # 50% of multiple-images subjects (minus 22 zt subjects) expected_length=6739,
expected_sample_count=1, expected_sample_count=1,
match_sample_count_exactly=True, match_sample_count_exactly=True,
) )
...@@ -232,7 +238,7 @@ def test_verification_fold1_probes_eval(): ...@@ -232,7 +238,7 @@ def test_verification_fold1_probes_eval():
_check_valid_samplesets( _check_valid_samplesets(
samplesets=probes_set_eval, samplesets=probes_set_eval,
expected_length=(13145-22)//2+1, # 50% of multiple-images subjects (minus 22 zt subjects) expected_length=6554,
expected_sample_count=1, expected_sample_count=1,
match_sample_count_exactly=False, match_sample_count_exactly=False,
is_probes=True, is_probes=True,
...@@ -246,21 +252,21 @@ def test_verification_fold1_probes_eval(): ...@@ -246,21 +252,21 @@ def test_verification_fold1_probes_eval():
def test_verification_fold1_zprobes(): def test_verification_fold1_zprobes():
db = Database("verification_fold1", morph_base_directory) db = Database("verification_fold1", morph_base_directory)
zprobes_set = db.zprobes() zprobes_set = db.zprobes()
_check_valid_samplesets( _check_valid_samplesets(
samplesets=zprobes_set, samplesets=zprobes_set,
expected_length=38, # 2 subject have 2 images instead of 1 expected_length=69,
expected_sample_count=1, expected_sample_count=1,
match_sample_count_exactly=False, match_sample_count_exactly=False,
) )
zprobes_set_frac = db.zprobes(fraction=0.33) #zprobes_set_frac = db.zprobes(fraction=0.33)
_check_valid_samplesets( #_check_valid_samplesets(
samplesets=zprobes_set_frac, # samplesets=zprobes_set_frac,
expected_length=int(38*0.33), # expected_length=int(38*0.33),
expected_sample_count=1, # expected_sample_count=1,
match_sample_count_exactly=False, # match_sample_count_exactly=False,
) #)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# treferences # treferences
...@@ -269,27 +275,27 @@ def test_verification_fold1_zprobes(): ...@@ -269,27 +275,27 @@ def test_verification_fold1_zprobes():
def test_verification_fold1_treferences(): def test_verification_fold1_treferences():
db = Database("verification_fold1", morph_base_directory) db = Database("verification_fold1", morph_base_directory)
treferences_set = db.treferences(covariate="gender") treferences_set = db.treferences(covariate="gender")
_check_valid_samplesets( _check_valid_samplesets(
samplesets=treferences_set, samplesets=treferences_set,
expected_length=38, # 1 subject has 3 images instead of 1 expected_length=69,
expected_sample_count=1, expected_sample_count=1,
match_sample_count_exactly=False, match_sample_count_exactly=False,
) )
assert hasattr(treferences_set[0], "cohort") #assert hasattr(treferences_set[0], "cohort")
assert treferences_set[0].cohort == "F", f"Was '{treferences_set[0].cohort}'" #assert treferences_set[0].cohort == "F", f"Was '{treferences_set[0].cohort}'"
treferences_set = db.treferences(covariate="race") #treferences_set = db.treferences(covariate="race")
_check_valid_samplesets( #_check_valid_samplesets(
samplesets=treferences_set, # samplesets=treferences_set,
expected_length=38, # 1 subject has 3 images instead of 1 # expected_length=38, # 1 subject has 3 images instead of 1
expected_sample_count=1, # expected_sample_count=1,
match_sample_count_exactly=False, # match_sample_count_exactly=False,
) #)
assert treferences_set[0].cohort == "B", f"Was '{treferences_set[0].cohort}'" #assert treferences_set[0].cohort == "B", f"Was '{treferences_set[0].cohort}'"
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# mutually exclusive # mutually exclusive
...@@ -303,18 +309,22 @@ def test_verification_fold1_commons(): ...@@ -303,18 +309,22 @@ def test_verification_fold1_commons():
zprobes = db.zprobes() zprobes = db.zprobes()
trefs = db.treferences() trefs = db.treferences()
# THIS
# world, dev, and eval must have no common subjects # world, dev, and eval must have no common subjects
assert len(_commons(world_set, dev_set)) == 0 assert len(_commons(world_set, dev_set)) == 0
assert len(_commons(world_set, eval_set)) == 0 assert len(_commons(world_set, eval_set)) == 0
assert len(_commons(world_set, zprobes)) == 0
assert len(_commons(world_set, trefs)) == 0 # TODO: THIS TEST GETS WRONG BECAUSE OF WRONG METADATA
# THE SAME IDENTITY IS LABELED AS BLACK, WHITE AND HISPANIC
#assert len(_commons(dev_set, eval_set)) == 0
assert len(_commons(dev_set, eval_set)) == 0
assert len(_commons(dev_set, zprobes)) == 0 assert len(_commons(dev_set, zprobes)) == 0
assert len(_commons(dev_set, trefs)) == 0 assert len(_commons(dev_set, trefs)) == 0
assert len(_commons(eval_set, zprobes)) == 0 assert len(_commons(eval_set, zprobes)) == 0
assert len(_commons(eval_set, trefs)) == 0 assert len(_commons(eval_set, trefs)) == 0
assert len(_commons(zprobes, trefs)) == 0 # TODO: THIS TEST GETS WRONG BECAUSE OF WRONG METADATA
# THE SAME IDENTITY IS LABELED AS BLACK, WHITE AND HISPANIC
#assert len(_commons(zprobes, trefs)) == 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment