diff --git a/bob/bio/face/algorithm/GaborJet.py b/bob/bio/face/algorithm/GaborJet.py index a878819f1f53f78cc60a8ba48bedde34a6a1dd59..54879038e63d6171a80bf53887958191d87f1463 100644 --- a/bob/bio/face/algorithm/GaborJet.py +++ b/bob/bio/face/algorithm/GaborJet.py @@ -11,6 +11,7 @@ import math from bob.bio.base.algorithm import Algorithm from bob.bio.face.extractor import GridGraph + class GaborJet(Algorithm): """Computes a comparison of lists of Gabor jets using a similarity function of :py:class:`bob.ip.gabor.Similarity`. @@ -72,21 +73,20 @@ class GaborJet(Algorithm): multiple_probe_scoring=None, ) - self.gabor_jet_similarity_type=gabor_jet_similarity_type - self.multiple_feature_scoring=multiple_feature_scoring - self.gabor_directions=gabor_directions - self.gabor_scales=gabor_scales - self.gabor_sigma=gabor_sigma - self.gabor_maximum_frequency=gabor_maximum_frequency - self.gabor_frequency_step=gabor_frequency_step - self.gabor_power_of_k=gabor_power_of_k - self.gabor_dc_free=gabor_dc_free + self.gabor_jet_similarity_type = gabor_jet_similarity_type + self.multiple_feature_scoring = multiple_feature_scoring + self.gabor_directions = gabor_directions + self.gabor_scales = gabor_scales + self.gabor_sigma = gabor_sigma + self.gabor_maximum_frequency = gabor_maximum_frequency + self.gabor_frequency_step = gabor_frequency_step + self.gabor_power_of_k = gabor_power_of_k + self.gabor_dc_free = gabor_dc_free self.gabor_jet_similarity_type = gabor_jet_similarity_type self._init_non_pickables() - def _init_non_pickables(self): # the Gabor wavelet transform; used by (some of) the Gabor jet similarities self.gwt = bob.ip.gabor.Transform( @@ -128,9 +128,10 @@ class GaborJet(Algorithm): }[self.multiple_feature_scoring] def _check_feature(self, feature): - # import ipdb; ipdb.set_trace() + assert isinstance(feature, list) or isinstance(feature, numpy.ndarray) assert len(feature) + feature = GridGraph.serialize_jets(feature) assert all(isinstance(f, bob.ip.gabor.Jet) for f in feature) def enroll(self, enroll_features): @@ -215,7 +216,7 @@ class GaborJet(Algorithm): model = [] for g in range(count): name = "Node-" + str(g + 1) - f.cd(name) + f.cd(name) model.append(GridGraph.serialize_jets(bob.ip.gabor.load_jets(f))) f.cd("..") return model @@ -354,4 +355,3 @@ class GaborJet(Algorithm): def __setstate__(self, d): self.__dict__ = d self._init_non_pickables() - diff --git a/bob/bio/face/extractor/GridGraph.py b/bob/bio/face/extractor/GridGraph.py index 89e8e5dc14b56cc9fc55b97ff8bf5b92c9df4118..a3b6e24807f60a9cfed7ed9e847a925f7950be31 100644 --- a/bob/bio/face/extractor/GridGraph.py +++ b/bob/bio/face/extractor/GridGraph.py @@ -140,7 +140,7 @@ class GridGraph(Extractor): "Please specify either 'eyes' or the grid parameters 'node_distance' (and 'first_node')!" ) self._aligned_graph = None - self._last_image_resolution = None + self._last_image_resolution = None if isinstance(self.node_distance, (int, float)): self.node_distance = (int(self.node_distance), int(self.node_distance)) @@ -229,7 +229,7 @@ class GridGraph(Extractor): [j.normalize() for j in jets] # return the extracted face graph - return jets + return self.__class__.serialize_jets(jets) def write_feature(self, feature, feature_file): """Writes the feature extracted by the `__call__` function to the given file. @@ -264,7 +264,9 @@ class GridGraph(Extractor): feature : [:py:class:`bob.ip.gabor.Jet`] The list of Gabor jets read from file. """ - return self.__class__.serialize_jets(bob.ip.gabor.load_jets(bob.io.base.HDF5File(feature_file))) + return self.__class__.serialize_jets( + bob.ip.gabor.load_jets(bob.io.base.HDF5File(feature_file)) + ) # re-define the train function to get it non-documented def train(*args, **kwargs): @@ -277,7 +279,7 @@ class GridGraph(Extractor): def __getstate__(self): d = dict(self.__dict__) - d.pop("gwt") + d.pop("gwt") d.pop("_aligned_graph") if "_graph" in d: d.pop("_graph") @@ -287,12 +289,11 @@ class GridGraph(Extractor): self.__dict__ = d self._init_non_pickables() - @staticmethod def serialize_jets(jets): - serialize_jets = [] + serialize_jets = [] for jet in jets: sj = bob.ip.gabor.Jet(jet.length) sj.jet = jet.jet - serialize_jets.append(sj) + serialize_jets.append(sj) return serialize_jets