diff --git a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py
index 2979f706fe46434a2cc95f34d831f2639bf29516..39385c14b82686792996c3d67b6fee7c0fa52ca2 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py
@@ -13,7 +13,7 @@ from .abstract_classes import (
     Database,
 )
 from bob.io.base import HDF5File
-from bob.pipelines import DelayedSample, SampleSet, Sample
+from bob.pipelines import DelayedSample, SampleSet, Sample, DelayedSampleSet
 import logging
 import copy
 import joblib
@@ -328,10 +328,10 @@ class BioAlgorithmLegacy(BioAlgorithm):
             )
 
             self.write_scores(scored_sample_set.samples, path)
-            scored_sample_set = SampleSet(
-                DelayedSample(functools.partial(_load, path), parent=sampleset),
-                parent=sampleset,
-            )
+
+            scored_sample_set = DelayedSampleSet(functools.partial(_load, path),
+                                                 parent=scored_sample_set)
+
         else:
             scored_sample_set = SampleSet(_load(path), parent=sampleset)
 
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
index 812156274f835e985a0048e52f4428b2a7093318..0a8f2a844f2f1fd952b22175f5d33c8f55f05594 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
@@ -1,4 +1,4 @@
-from bob.pipelines import DelayedSample, SampleSet, Sample
+from bob.pipelines import DelayedSample, SampleSet, Sample, DelayedSampleSet
 import bob.io.base
 import os
 import dask
@@ -20,8 +20,12 @@ from bob.pipelines.wrappers import SampleWrapper
 from bob.pipelines.distributed.sge import SGEMultipleQueuesCluster
 import joblib
 import logging
+from bob.pipelines.utils import isinstance_nested
+import gc
+
 
 logger = logging.getLogger(__name__)
+
 
 class BioAlgorithmCheckpointWrapper(BioAlgorithm):
     """Wrapper used to checkpoint enrolled and Scoring samples.
@@ -89,16 +93,16 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
 
     def write_scores(self, samples, path):
         os.makedirs(os.path.dirname(path), exist_ok=True)
-
+        gc.collect()
         joblib.dump(samples, path, compress=4)
 
         # cleaning parent
-        #with open(path, "wb") as f:
+        # with open(path, "wb") as f:
         #    f.write(cloudpickle.dumps(samples))
         #    f.flush()
 
-        #from bob.pipelines.sample import sample_to_hdf5
-        #with h5py.File(path, "w") as hdf5:
+        # from bob.pipelines.sample import sample_to_hdf5
+        # with h5py.File(path, "w") as hdf5:
         #    sample_to_hdf5(samples, hdf5)
 
     def _enroll_sample_set(self, sampleset):
@@ -135,13 +139,13 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
 
         """
 
         def _load(path):
-
+            gc.collect()
             return joblib.load(path)
-            #return cloudpickle.loads(open(path, "rb").read())
-
-            #from bob.pipelines.sample import hdf5_to_sample
-            #with h5py.File(path) as hdf5:
+            # return cloudpickle.loads(open(path, "rb").read())
+
+            # from bob.pipelines.sample import hdf5_to_sample
+            # with h5py.File(path) as hdf5:
             #    return hdf5_to_sample(hdf5)
 
         def _make_name(sampleset, biometric_references):
@@ -167,10 +171,10 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
             )
 
             self.write_scores(scored_sample_set.samples, path)
-            scored_sample_set = SampleSet(
-                DelayedSample(functools.partial(_load, path), parent=sampleset),
-                parent=sampleset,
+            scored_sample_set = DelayedSampleSet(
+                functools.partial(_load, path), parent=scored_sample_set
             )
+
         else:
             samples = _load(path)
             scored_sample_set = SampleSet(samples, parent=sampleset)
@@ -191,6 +195,7 @@ class BioAlgorithmDaskWrapper(BioAlgorithm):
         biometric_references = biometric_reference_features.map_partitions(
             self.biometric_algorithm.enroll_samples
         )
+
         return biometric_references
 
     def score_samples(
@@ -207,7 +212,7 @@
 
         # the disk, directly. A second option would be to generate named delays
         # for each model and then associate them here.
-        all_references = dask.delayed(list)(biometric_references)
+        all_references = dask.delayed(list)(biometric_references)
         scores = probe_features.map_partitions(
             self.biometric_algorithm.score_samples,
             all_references,
@@ -251,7 +256,9 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
     if isinstance(pipeline, ZTNormPipeline):
         # Dasking the first part of the pipelines
         pipeline.vanilla_biometrics_pipeline = dask_vanilla_biometrics(
-            pipeline.vanilla_biometrics_pipeline, npartitions=npartitions, partition_size=partition_size
+            pipeline.vanilla_biometrics_pipeline,
+            npartitions=npartitions,
+            partition_size=partition_size,
         )
         pipeline.biometric_algorithm = (
             pipeline.vanilla_biometrics_pipeline.biometric_algorithm
@@ -303,7 +310,7 @@ def dask_get_partition_size(cluster, n_objects):
         return None
 
     max_jobs = cluster.sge_job_spec["default"]["max_jobs"]
-    return n_objects // (max_jobs*2) if n_objects > max_jobs else 1
+    return n_objects // (max_jobs * 2) if n_objects > max_jobs else 1
 
 
 def checkpoint_vanilla_biometrics(pipeline, base_dir, biometric_algorithm_dir=None):
@@ -375,3 +382,21 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir, biometric_algorithm_dir=No
         )
 
     return pipeline
+
+
+def is_checkpointed(pipeline):
+    """
+    Check if :any:`VanillaBiometrics` is checkpointed
+
+
+    Parameters
+    ----------
+
+    pipeline: :any:`VanillaBiometrics`
+        Vanilla Biometrics based pipeline to be checked
+
+    """
+
+    return isinstance_nested(
+        pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper
+    ) or isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy)
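
A minimal usage sketch of the new helper (not part of the patch): it assumes a vanilla biometrics pipeline object built elsewhere, and it imports directly from the module patched above. The build_pipeline() factory and the "./checkpoints" directory are hypothetical placeholders.

# Sketch only: wrap a pipeline with checkpointing and verify it with the new
# is_checkpointed() helper. Both functions come from the patched wrappers module.
from bob.bio.base.pipelines.vanilla_biometrics.wrappers import (
    checkpoint_vanilla_biometrics,
    is_checkpointed,
)

pipeline = build_pipeline()  # hypothetical factory, defined by the caller
pipeline = checkpoint_vanilla_biometrics(pipeline, "./checkpoints")
# biometric_algorithm is now wrapped by BioAlgorithmCheckpointWrapper, so the
# nested isinstance check performed by is_checkpointed() should succeed.
assert is_checkpointed(pipeline)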