Commit 8cf2dd55 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Serializing parts of gaborjet

parent e9e7bbb7
Pipeline #39233 failed with stage
in 4 minutes and 10 seconds
......@@ -11,6 +11,7 @@ import math
from bob.bio.base.algorithm import Algorithm
from bob.bio.face.extractor import GridGraph
class GaborJet(Algorithm):
"""Computes a comparison of lists of Gabor jets using a similarity function of :py:class:`bob.ip.gabor.Similarity`.
......@@ -72,21 +73,20 @@ class GaborJet(Algorithm):
multiple_probe_scoring=None,
)
self.gabor_jet_similarity_type=gabor_jet_similarity_type
self.multiple_feature_scoring=multiple_feature_scoring
self.gabor_directions=gabor_directions
self.gabor_scales=gabor_scales
self.gabor_sigma=gabor_sigma
self.gabor_maximum_frequency=gabor_maximum_frequency
self.gabor_frequency_step=gabor_frequency_step
self.gabor_power_of_k=gabor_power_of_k
self.gabor_dc_free=gabor_dc_free
self.gabor_jet_similarity_type = gabor_jet_similarity_type
self.multiple_feature_scoring = multiple_feature_scoring
self.gabor_directions = gabor_directions
self.gabor_scales = gabor_scales
self.gabor_sigma = gabor_sigma
self.gabor_maximum_frequency = gabor_maximum_frequency
self.gabor_frequency_step = gabor_frequency_step
self.gabor_power_of_k = gabor_power_of_k
self.gabor_dc_free = gabor_dc_free
self.gabor_jet_similarity_type = gabor_jet_similarity_type
self._init_non_pickables()
def _init_non_pickables(self):
# the Gabor wavelet transform; used by (some of) the Gabor jet similarities
self.gwt = bob.ip.gabor.Transform(
......@@ -128,9 +128,10 @@ class GaborJet(Algorithm):
}[self.multiple_feature_scoring]
def _check_feature(self, feature):
# import ipdb; ipdb.set_trace()
assert isinstance(feature, list) or isinstance(feature, numpy.ndarray)
assert len(feature)
feature = GridGraph.serialize_jets(feature)
assert all(isinstance(f, bob.ip.gabor.Jet) for f in feature)
def enroll(self, enroll_features):
......@@ -215,7 +216,7 @@ class GaborJet(Algorithm):
model = []
for g in range(count):
name = "Node-" + str(g + 1)
f.cd(name)
f.cd(name)
model.append(GridGraph.serialize_jets(bob.ip.gabor.load_jets(f)))
f.cd("..")
return model
......@@ -354,4 +355,3 @@ class GaborJet(Algorithm):
def __setstate__(self, d):
self.__dict__ = d
self._init_non_pickables()
......@@ -140,7 +140,7 @@ class GridGraph(Extractor):
"Please specify either 'eyes' or the grid parameters 'node_distance' (and 'first_node')!"
)
self._aligned_graph = None
self._last_image_resolution = None
self._last_image_resolution = None
if isinstance(self.node_distance, (int, float)):
self.node_distance = (int(self.node_distance), int(self.node_distance))
......@@ -229,7 +229,7 @@ class GridGraph(Extractor):
[j.normalize() for j in jets]
# return the extracted face graph
return jets
return self.__class__.serialize_jets(jets)
def write_feature(self, feature, feature_file):
"""Writes the feature extracted by the `__call__` function to the given file.
......@@ -264,7 +264,9 @@ class GridGraph(Extractor):
feature : [:py:class:`bob.ip.gabor.Jet`]
The list of Gabor jets read from file.
"""
return self.__class__.serialize_jets(bob.ip.gabor.load_jets(bob.io.base.HDF5File(feature_file)))
return self.__class__.serialize_jets(
bob.ip.gabor.load_jets(bob.io.base.HDF5File(feature_file))
)
# re-define the train function to get it non-documented
def train(*args, **kwargs):
......@@ -277,7 +279,7 @@ class GridGraph(Extractor):
def __getstate__(self):
d = dict(self.__dict__)
d.pop("gwt")
d.pop("gwt")
d.pop("_aligned_graph")
if "_graph" in d:
d.pop("_graph")
......@@ -287,12 +289,11 @@ class GridGraph(Extractor):
self.__dict__ = d
self._init_non_pickables()
@staticmethod
def serialize_jets(jets):
serialize_jets = []
serialize_jets = []
for jet in jets:
sj = bob.ip.gabor.Jet(jet.length)
sj.jet = jet.jet
serialize_jets.append(sj)
serialize_jets.append(sj)
return serialize_jets
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment