Commit a327a1b6 authored by Tiago de Freitas Pereira

Merge branch 'checkpoint' into 'master'

Score checkpoints are more robust

See merge request !245
parents 84dbb997 f775ae4a
1 merge request !245: Score checkpoints are more robust
Pipeline #49512 passed
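In short: score checkpoints are now written as gzip-compressed pickles through a write-and-verify helper with retries (pickle_compress / uncompress_unpickle) instead of joblib dumps, and the score checkpoint extension changes from ".joblib" to ".pickle.gz".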
@@ -8,6 +8,7 @@ from . import pipelines
 from . import script
 from . import test
+from . import score
 def get_config():
......
 from .pipelines import VanillaBiometricsPipeline
+import pickle
+import gzip
+import os
+
+
+def pickle_compress(path, obj, attempts=5):
+    """
+    Pickle an object, compress it, and save it to disk.
+
+    Parameters
+    ----------
+    path: str
+        Path where the object is saved
+    obj:
+        Object to be saved
+    attempts: int
+        Number of serialization attempts
+    """
+    for _ in range(attempts):
+        try:
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            # Writing may fail on our file system, hence the retry loop
+            with gzip.open(path, "wb") as f:
+                f.write(pickle.dumps(obj))
+            # Verify the write: decompress and unpickle it back
+            uncompress_unpickle(path)
+            break
+        except Exception:
+            continue
+    else:
+        # All attempts failed
+        raise EOFError(f"Failed to serialize/deserialize {path}")
+
+
+def uncompress_unpickle(path):
+    with gzip.open(path, "rb") as f:
+        return pickle.loads(f.read())
+
 from .biometric_algorithms import Distance
 from .score_writers import FourColumnsScoreWriter, CSVScoreWriter
 from .wrappers import (
......
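These two helpers are the core of the change: a write only counts as successful once the file has been decompressed and unpickled back, and failures are retried. A minimal round-trip sketch (the path and score payload below are hypothetical; the helpers are assumed to be in scope as defined above):

scores = [("model_1", "probe_1", 0.87), ("model_1", "probe_2", 0.12)]  # hypothetical payload
path = "/tmp/checkpoints/scores.pickle.gz"  # hypothetical checkpoint location

pickle_compress(path, scores)         # gzip-compressed pickle, verified after writing
restored = uncompress_unpickle(path)  # load it back
assert restored == scores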
@@ -14,7 +14,6 @@ from .abstract_classes import BioAlgorithm
 import bob.pipelines
 import numpy as np
 import h5py
-import cloudpickle
 from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper
 from .legacy import BioAlgorithmLegacy
 from bob.bio.base.transformers import (
@@ -24,10 +23,11 @@ from bob.bio.base.transformers import (
 )
 from bob.pipelines.wrappers import SampleWrapper, CheckpointWrapper
 from bob.pipelines.distributed.sge import SGEMultipleQueuesCluster
 import joblib
 import logging
 from bob.pipelines.utils import isinstance_nested
 import gc
 import time
+from . import pickle_compress, uncompress_unpickle
 logger = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
         self.biometric_algorithm = biometric_algorithm
         self.force = force
         self._biometric_reference_extension = ".hdf5"
-        self._score_extension = ".joblib"
+        self._score_extension = ".pickle.gz"
 
     def clear_caches(self):
         self.biometric_algorithm.clear_caches()
@@ -101,18 +101,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
         return bob.io.base.save(sample.data, path, create_directories=True)
 
     def write_scores(self, samples, path):
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        gc.collect()
-        joblib.dump(samples, path, compress=4)
-
-        # cleaning parent
-        # with open(path, "wb") as f:
-        #     f.write(cloudpickle.dumps(samples))
-        #     f.flush()
-
-        # from bob.pipelines.sample import sample_to_hdf5
-        # with h5py.File(path, "w") as hdf5:
-        #     sample_to_hdf5(samples, hdf5)
+        pickle_compress(path, samples)
 
     def _enroll_sample_set(self, sampleset):
         """
@@ -148,14 +137,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
         """
 
         def _load(path):
-            gc.collect()
-            return joblib.load(path)
-
-            # return cloudpickle.loads(open(path, "rb").read())
-
-            # from bob.pipelines.sample import hdf5_to_sample
-            # with h5py.File(path) as hdf5:
-            #     return hdf5_to_sample(hdf5)
+            return uncompress_unpickle(path)
 
         def _make_name(sampleset, biometric_references):
             # The score file name is composed by sampleset key and the
......
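With this, every score file the checkpoint wrapper touches goes through the verified gzip-pickle path. A hypothetical wiring sketch (the base_dir argument is an assumption based on the attributes visible in the diff; only biometric_algorithm and force are confirmed by it):

from bob.bio.base.pipelines.vanilla_biometrics import Distance

# Hypothetical: wrap a Distance algorithm so its scores are checkpointed
# as ".pickle.gz" files under base_dir instead of ".joblib" files.
wrapped = BioAlgorithmCheckpointWrapper(
    biometric_algorithm=Distance(), base_dir="/tmp/checkpoints", force=False
)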
@@ -13,6 +13,8 @@ from bob.pipelines import (
     DelayedSampleSet,
     DelayedSampleSetCached,
 )
+import numpy as np
+import dask
 import functools
@@ -20,9 +22,9 @@ import cloudpickle
 import os
 from .score_writers import FourColumnsScoreWriter
 import copy
-import joblib
 import logging
 from .pipelines import check_valid_pipeline
+from . import pickle_compress, uncompress_unpickle
 logger = logging.getLogger(__name__)
@@ -596,16 +598,13 @@ class ZTNormCheckpointWrapper(object):
         self.force = force
         self.base_dir = base_dir
-        self._score_extension = ".joblib"
+        self._score_extension = ".pickle.gz"
 
     def write_scores(self, samples, path):
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        # open(path, "wb").write(cloudpickle.dumps(samples))
-        joblib.dump(samples, path, compress=4)
+        pickle_compress(path, samples)
 
     def _load(self, path):
-        # return cloudpickle.loads(open(path, "rb").read())
-        return joblib.load(path)
+        return uncompress_unpickle(path)
 
     def _make_name(self, sampleset, biometric_references, for_zt=False):
         # The score file name is composed by sampleset key and the
......
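Because the extension changes from ".joblib" to ".pickle.gz", checkpoints written before this merge will not be found and their scores will be recomputed. If recomputation is expensive, a one-off conversion along these lines is possible (a hypothetical sketch: it assumes pickle_compress from this merge is importable and that the old checkpoints live under a single scores directory):

import glob
import os
import joblib

# Hypothetical: convert old joblib score checkpoints in place.
for old_path in glob.glob("/path/to/scores/**/*.joblib", recursive=True):
    samples = joblib.load(old_path)                        # old format
    new_path = old_path[: -len(".joblib")] + ".pickle.gz"
    pickle_compress(new_path, samples)                     # new, verified format
    os.remove(old_path)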