Skip to content
Snippets Groups Projects
Commit 3a5d5b89 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch '56-it-is-not-possible-to-use-plda-in-the-current-execution-chain' into 'master'

Correctly reorganized the input for the PLDA

Closes #56 

I'll give 24h for objections before merge this one :-P

See merge request !55
parents 655b88cf 824b0b6c
No related branches found
No related tags found
1 merge request!55Correctly reorganized the input for the PLDA
Pipeline #
......@@ -73,7 +73,7 @@ class PLDA (Algorithm):
def _train_pca(self, training_set):
"""Trains and returns a LinearMachine that is trained using PCA"""
data = numpy.vstack([feature for client in training_set for feature in client])
data = numpy.vstack([feature for feature in training_set])
logger.info(" -> Training LinearMachine using PCA ")
trainer = bob.learn.linear.PCATrainer()
......@@ -92,20 +92,30 @@ class PLDA (Algorithm):
machine.resize(machine.shape[0], self.subspace_dimension_pca)
return machine
def _perform_pca_client(self, client):
"""Perform PCA on an array"""
return numpy.vstack([self.pca_machine(feature) for feature in client])
def _perform_pca(self, training_set):
"""Perform PCA on data"""
return [self._perform_pca_client(client) for client in training_set]
return [self.pca_machine(client) for client in training_set]
def _arrange_data(self, training_files):
"""Arranges the data to train the PLDA """
data = []
for client_files in training_files:
# at least two files per client are required!
if len(client_files) < 2:
logger.warn("Skipping one client since the number of client files is only %d", len(client_files))
continue
data.append(numpy.vstack([feature.flatten() for feature in client_files]))
# Returns the list of lists of arrays
return data
def train_enroller(self, training_features, projector_file):
"""Generates the PLDA base model from a list of arrays (one per identity),
and a set of training parameters. If PCA is requested, it is trained on the same data.
Both the trained PLDABase and the PCA machine are written."""
# arrange PLDA training data
training_features = self._arrange_data(training_features)
# train PCA and perform PCA on training data
if self.subspace_dimension_pca is not None:
......@@ -113,6 +123,7 @@ class PLDA (Algorithm):
training_features = self._perform_pca(training_features)
input_dimension = training_features[0].shape[1]
logger.info(" -> Training PLDA base machine")
# train machine
......@@ -146,7 +157,7 @@ class PLDA (Algorithm):
plda_machine = bob.learn.em.PLDAMachine(self.plda_base)
# project features, if enabled
if self.pca_machine is not None:
enroll_features = self._perform_pca_client(enroll_features)
enroll_features = self._perform_pca(enroll_features)
# enroll
self.plda_trainer.enroll(plda_machine, enroll_features)
return plda_machine
......
File added
File added
......@@ -358,3 +358,42 @@ def test_plda():
reference_score = 0.
assert abs(plda1.score(model, feature) - reference_score) < 1e-5, "The scores differ: %3.8f, %3.8f" % (plda1.score(model, feature), reference_score)
assert abs(plda1.score_for_multiple_probes(model, [feature, feature]) - reference_score) < 1e-5
def test_plda_nopca():
temp_file = bob.io.base.test_utils.temporary_filename()
plda_ref = bob.bio.base.load_resource("plda", "algorithm", preferred_package = 'bob.bio.base')
reference_file = pkg_resources.resource_filename('bob.bio.base.test', 'data/plda_nopca_enroller.hdf5')
plda_ref.load_enroller(reference_file)
# generate a smaller PCA subspcae
plda = bob.bio.base.algorithm.PLDA(subspace_dimension_of_f = 2, subspace_dimension_of_g = 2, plda_training_iterations = 1, INIT_SEED = seed_value)
# create random training set
train_set = utils.random_training_set_by_id(200, count=20, minimum=0., maximum=255.)
# train the projector
try:
# train projector
plda.train_enroller(train_set, temp_file)
assert os.path.exists(temp_file)
if regenerate_refs: shutil.copy(temp_file, reference_file)
# check projection matrix
assert plda.plda_base.is_similar_to(plda_ref.plda_base)
finally:
if os.path.exists(temp_file): os.remove(temp_file)
# generate and project random feature
feature = utils.random_array(200, 0., 255., seed=84)
# enroll model from random features
reference = pkg_resources.resource_filename('bob.bio.base.test', 'data/plda_nopca_model.hdf5')
model = plda.enroll([feature])
# execute the preprocessor
if regenerate_refs:
plda.write_model(model, reference)
reference = plda.read_model(reference)
assert model.is_similar_to(reference)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment