Skip to content
Snippets Groups Projects
Commit 555b7d61 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixing the legacy AlgorithmAsBioAlg

parent aad2b4de
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
......@@ -8,7 +8,12 @@ import functools
from collections import defaultdict
from bob.bio.base import utils
from .abstract_classes import BioAlgorithm, Database, save_scores_four_columns
from .abstract_classes import (
BioAlgorithm,
Database,
create_score_delayed_sample,
make_four_colums_score,
)
from bob.io.base import HDF5File
from bob.pipelines.mixins import SampleMixin, CheckpointMixin
from bob.pipelines.sample import DelayedSample, SampleSet, Sample
......@@ -236,7 +241,7 @@ class _Extractor(_NonPickableWrapper, TransformerMixin, BaseEstimator):
class Extractor(CheckpointMixin, SampleMixin, _Extractor):
def __init__(self, callable, model_path, **kwargs):
def __init__(self, callable, model_path=None, **kwargs):
instance = callable()
transform_extra_arguments = None
......@@ -390,7 +395,11 @@ class AlgorithmAsBioAlg(BioAlgorithm, _NonPickableWrapper):
for ref in [
r for r in biometric_references if r.key in sampleset.references
]:
subprobe_scores.append(Sample(self.score(ref.data, s.data), parent=ref))
score = self.score(ref.data, s.data)
data = make_four_colums_score(
ref.subject, sampleset.subject, sampleset.path, score
)
subprobe_scores.append(Sample(data, parent=ref))
# Creating one sampleset per probe
subprobe = SampleSet(subprobe_scores, parent=sampleset)
......@@ -400,7 +409,7 @@ class AlgorithmAsBioAlg(BioAlgorithm, _NonPickableWrapper):
path = os.path.join(self.score_dir, str(subprobe.path) + ".txt")
os.makedirs(os.path.dirname(path), exist_ok=True)
delayed_scored_sample = save_scores_four_columns(path, subprobe)
delayed_scored_sample = create_score_delayed_sample(path, subprobe)
subprobe.samples = [delayed_scored_sample]
retval.append(subprobe)
......@@ -418,15 +427,13 @@ class AlgorithmAsBioAlg(BioAlgorithm, _NonPickableWrapper):
self.biometric_reference_dir, str(enroll_features.key) + self.extension
)
if path is None or not os.path.isfile(path):
# Enrolling
data = [s.data for s in enroll_features.samples]
model = self.instance.enroll(data)
# Checkpointing
os.makedirs(os.path.dirname(path), exist_ok=True)
hdf5 = HDF5File(path, "w")
self.instance.write_model(model, hdf5)
self.instance.write_model(model, path)
reader = _get_pickable_method(self.instance.read_model)
return DelayedSample(functools.partial(reader, path), parent=enroll_features)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment