Skip to content
Snippets Groups Projects

New protocol distribution

Merged Tiago de Freitas Pereira requested to merge new_protocol into master
5 files
+ 318
- 281
Compare changes
  • Side-by-side
  • Inline
Files
5
+ 25
- 36
@@ -17,8 +17,9 @@ from bob.pipelines.sample import SampleSet, DelayedSample
from .protocol import Protocol
import bob.extension.log
logger = bob.extension.log.setup("bob.db.morph")
import logging
logger = logging.getLogger(__name__)
import copy
class Database:
@@ -58,8 +59,6 @@ class Database:
self.extension = original_extension # NOT TAKEN INTO ACCOUNT
# extension is already present in metadata filenames
# Request the correct protocol definition object
self.protocol = Protocol(protocol)
# Using a local copy of the morph_2008_nonCommercial.csv given with
# the morph dataset.
@@ -81,11 +80,14 @@ class Database:
"photo", # File name of the sample
]]
logger.debug(f"Filtering protocol genders.")
self.dataframe = self.protocol.filter_gender(self.dataframe)
#logger.debug(f"Filtering protocol genders.")
#self.dataframe = self.protocol.filter_gender(self.dataframe)
#logger.debug(f"Filtering protocol ethnicities.")
#self.dataframe = self.protocol.filter_ethnicity(self.dataframe)
logger.debug(f"Filtering protocol ethnicities.")
self.dataframe = self.protocol.filter_ethnicity(self.dataframe)
# Request the correct protocol definition object
self.protocol = Protocol(protocol, self.dataframe)
# Using a local copy of the MORPH_Album2_EYECOORDS.csv given with morph
eyecoords_file = (
@@ -124,8 +126,9 @@ class Database:
f"'{self.protocol.name}'."
)
# Filter the dataframe to keep only 'world' subject (with Protocol)
world_dataframe = self.protocol.world_filter(self.dataframe)
world_dataframe = self.protocol.world_filter()
# Convert the filtered dataframe to a list of SampleSet
samplesets = self._create_list_of_samplesets(world_dataframe)
@@ -168,15 +171,9 @@ class Database:
)
group = "dev"
if group == "dev":
# Filter the dataframe to keep only 'dev' subjects of protocol
set_dataframe = self.protocol.dev_filter(self.dataframe)
elif group == "eval":
# Filter the dataframe to keep only 'eval' subjects of protocol
set_dataframe = self.protocol.eval_filter(self.dataframe)
# Keep only the references samples
refs_dataframe = self.protocol.references_filter(set_dataframe)
refs_dataframe = self.protocol.references_filter(group=group)
# Convert the filtered dataframe to a list of SampleSet
samplesets = self._create_list_of_samplesets(refs_dataframe)
@@ -216,17 +213,10 @@ class Database:
)
group = "dev"
if group == "dev":
# Filter the dataframe to keep only 'dev' subject (with Protocol)
set_dataframe = self.protocol.dev_filter(self.dataframe)
elif group == "eval":
# Filter the dataframe to keep only 'eval' subject (with Protocol)
set_dataframe = self.protocol.eval_filter(self.dataframe)
# Keep only the probes samples
probes_dataframe = self.protocol.probes_filter(set_dataframe)
probes_dataframe = self.protocol.probes_filter(group)
references_id = self.protocol.references_list(probes_dataframe)
references_id = self.protocol.references_list(group)
# Convert the filtered dataframe to a list of SampleSet
samplesets = self._create_list_of_samplesets(
@@ -260,7 +250,7 @@ class Database:
f"'{self.protocol.name}'."
)
zprobes_dataframe = self.protocol.z_probes_filter(self.dataframe)
zprobes_dataframe = self.protocol.z_probes_filter()
# Convert the filtered dataframe to a list of SampleSet
zprobes = self._create_list_of_samplesets(zprobes_dataframe)
@@ -291,7 +281,7 @@ class Database:
if covariate not in self.dataframe.columns:
raise ValueError(f"'{covariate}' not a column of metadata df.")
treferences_dataframe = self.protocol.t_references_filter(self.dataframe)
treferences_dataframe = self.protocol.t_references_filter()
# Convert the filtered dataframe to a list of SampleSet
treferences = self._create_list_of_samplesets(
@@ -403,7 +393,7 @@ class Database:
Each SampleSet object contains all the samples of one ID.
"""
sets = {} # Stores the resulting sequence of SampleSet (as dict now)
logger.debug(f" Creating SampleSets")
if covariate != None:
covariate_col = list(frame.columns).index(covariate)+1
@@ -415,12 +405,12 @@ class Database:
folder, file = row.photo.split('/')
path = os.path.join(folder, file[:3], file)
if subject not in sets:
logger.debug(f" Creating SampleSet for subject '{subject}'.")
#logger.debug(f" Creating SampleSet for subject '{subject}'.")
sets[subject] = SampleSet(
samples=[], # Start with an empty one, fill it below
key=self._subject_to_key(subject),
path=path,
subject=subject,
subject=str(subject),
date_of_birth=row.dob,
photo_date=row.doa,
age_phd=row.age,
@@ -433,11 +423,11 @@ class Database:
if self._subject_to_key(subject) not in references_ids:
references_ids = references_ids[:]
references_ids[0] = self._subject_to_key(subject)
sets[subject].references = references_ids
sets[subject].references = copy.deepcopy(references_ids)
logger.debug(
f" Adding Sample for subject '{subject}', image {row.photo}."
)
#logger.debug(
# f" Adding Sample for subject '{subject}', image {row.photo}."
#)
# Using SampleSet 'insert' method
sets[subject].insert(
index=-1, # Insert at last position
@@ -447,11 +437,10 @@ class Database:
os.path.join(self.directory, path),
),
key=path,
subject=subject,
subject=str(subject),
annotations=self._eyes_annotations(file.split('.')[0]),
)
)
if covariate != None:
sets[subject].cohort = row[covariate_col]
return list(sets.values())
Loading