Commit 74b600e9 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Cleaning up Extractors

parent 8f62ad63
......@@ -7,12 +7,10 @@ import bob.io.base
import numpy
import math
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.utils import check_array
from bob.pipelines.sample import SampleBatch
from bob.bio.base.extractor import Extractor
class GridGraph(TransformerMixin, BaseEstimator):
class GridGraph(Extractor):
"""Extracts Gabor jets in a grid structure [GHW12]_ using functionalities from :ref:`bob.ip.gabor <bob.ip.gabor>`.
The grid can be either aligned to the eye locations (in which case the grid might be rotated), or a fixed grid graph can be extracted.
......@@ -73,6 +71,26 @@ class GridGraph(TransformerMixin, BaseEstimator):
first_node=None, # one or two integral values, or None -> automatically determined
):
# call base class constructor
Extractor.__init__(
self,
gabor_directions=gabor_directions,
gabor_scales=gabor_scales,
gabor_sigma=gabor_sigma,
gabor_maximum_frequency=gabor_maximum_frequency,
gabor_frequency_step=gabor_frequency_step,
gabor_power_of_k=gabor_power_of_k,
gabor_dc_free=gabor_dc_free,
normalize_gabor_jets=normalize_gabor_jets,
eyes=eyes,
nodes_between_eyes=nodes_between_eyes,
nodes_along_eyes=nodes_along_eyes,
nodes_above_eyes=nodes_above_eyes,
nodes_below_eyes=nodes_below_eyes,
node_distance=node_distance,
first_node=first_node,
)
self.gabor_directions = gabor_directions
self.gabor_scales = gabor_scales
self.gabor_sigma = gabor_sigma
......@@ -121,18 +139,11 @@ class GridGraph(TransformerMixin, BaseEstimator):
raise ValueError(
"Please specify either 'eyes' or the grid parameters 'node_distance' (and 'first_node')!"
)
if not hasattr(self, "_last_image_resolution"):
self._last_image_resolution = None
if not hasattr(self, "_aligned_graph"):
self._aligned_graph = None
self._aligned_graph = None
self._last_image_resolution = None
if isinstance(self.node_distance, (int, float)):
self.node_distance = (int(self.node_distance), int(self.node_distance))
self._graph = None
def _extractor(self, image):
"""Creates an extractor based on the given image.
If an aligned graph was specified in the constructor, it is simply returned.
......@@ -151,7 +162,7 @@ class GridGraph(TransformerMixin, BaseEstimator):
return self._aligned_graph
# check if a new extractor needs to be created
if self._graph is None or self._last_image_resolution != image.shape:
if self._last_image_resolution != image.shape:
self._last_image_resolution = image.shape
if self.first_node is None:
# automatically compute the first node
......@@ -184,7 +195,7 @@ class GridGraph(TransformerMixin, BaseEstimator):
return self._graph
def transform(self, X):
def __call__(self, image):
"""__call__(image) -> feature
Returns a list of Gabor jets extracted from the given image.
......@@ -201,32 +212,24 @@ class GridGraph(TransformerMixin, BaseEstimator):
The 2D location of the jet's nodes is not returned.
"""
def _extract(image):
import ipdb; ipdb.set_trace()
assert image.ndim == 2
assert isinstance(image, numpy.ndarray)
image = image.astype(numpy.float64)
assert image.dtype == numpy.float64
assert image.ndim == 2
assert isinstance(image, numpy.ndarray)
image = image.astype(numpy.float64)
assert image.dtype == numpy.float64
extractor = self._extractor(image)
extractor = self._extractor(image)
# perform Gabor wavelet transform
self.gwt.transform(image, self.trafo_image)
# extract face graph
jets = extractor.extract(self.trafo_image)
# perform Gabor wavelet transform
self.gwt.transform(image, self.trafo_image)
# extract face graph
jets = extractor.extract(self.trafo_image)
# normalize the Gabor jets of the graph only
if self.normalize_jets:
[j.normalize() for j in jets]
# normalize the Gabor jets of the graph only
if self.normalize_jets:
[j.normalize() for j in jets]
# return the extracted face graph
return self.__class__.serialize_jets(jets)
if isinstance(X, SampleBatch):
return [_extract(x) for x in X]
else:
return _extract(X)
# return the extracted face graph
return self.__class__.serialize_jets(jets)
def write_feature(self, feature, feature_file):
"""Writes the feature extracted by the `__call__` function to the given file.
......@@ -265,6 +268,15 @@ class GridGraph(TransformerMixin, BaseEstimator):
bob.ip.gabor.load_jets(bob.io.base.HDF5File(feature_file))
)
# re-define the train function to get it non-documented
def train(*args, **kwargs):
raise NotImplementedError(
"This function is not implemented and should not be called."
)
def load(*args, **kwargs):
pass
def __getstate__(self):
d = dict(self.__dict__)
d.pop("gwt")
......@@ -285,9 +297,3 @@ class GridGraph(TransformerMixin, BaseEstimator):
sj.jet = jet.jet
serialize_jets.append(sj)
return serialize_jets
def _more_tags(self):
return {"stateless": True, "requires_fit": False}
def fit(self, X, y=None):
return self
......@@ -278,7 +278,7 @@ class LGBPHS(TransformerMixin, BaseEstimator):
if isinstance(X, SampleBatch):
return [_extract(x) for x in X]
else:
return _extract(X)
return _extract(X)
def __getstate__(self):
d = dict(self.__dict__)
......
......@@ -33,6 +33,14 @@ import pkg_resources
regenerate_refs = False
# Cropping
CROPPED_IMAGE_HEIGHT = 80
CROPPED_IMAGE_WIDTH = CROPPED_IMAGE_HEIGHT * 4 // 5
# eye positions for frontal images
RIGHT_EYE_POS = (CROPPED_IMAGE_HEIGHT // 5, CROPPED_IMAGE_WIDTH // 4 - 1)
LEFT_EYE_POS = (CROPPED_IMAGE_HEIGHT // 5, CROPPED_IMAGE_WIDTH // 4 * 3)
def _compare(
data,
......@@ -60,9 +68,10 @@ def _data():
def test_dct_blocks():
# read input
data = _data()
dct = bob.bio.base.load_resource(
"dct-blocks", "extractor", preferred_package="bob.bio.face"
dct = bob.bio.face.extractor.DCTBlocks(
block_size=12, block_overlap=11, number_of_dct_coefficients=45
)
assert isinstance(dct, bob.bio.face.extractor.DCTBlocks)
# generate smaller extractor, using mixed tuple and int input for the block size and overlap
......@@ -81,8 +90,13 @@ def test_dct_blocks():
def test_graphs():
data = _data()
graph = bob.bio.base.load_resource(
"grid-graph", "extractor", preferred_package="bob.bio.face"
graph = bob.bio.face.extractor.GridGraph(
# Gabor parameters
gabor_sigma=math.sqrt(2.0) * math.pi,
# what kind of information to extract
normalize_gabor_jets=True,
# setup of the fixed grid
node_distance=(8, 8),
)
assert isinstance(graph, bob.bio.face.extractor.GridGraph)
......@@ -90,7 +104,7 @@ def test_graphs():
graph = bob.bio.face.extractor.GridGraph(node_distance=24)
# extract features
feature = graph.transform(data)
feature = graph(data)
reference = pkg_resources.resource_filename(
"bob.bio.face.test", "data/graph_regular.hdf5"
......@@ -106,8 +120,9 @@ def test_graphs():
assert all(numpy.allclose(r.jet, f.jet) for r, f in zip(reference, feature))
# get reference face graph extractor
cropper = bob.bio.base.load_resource(
"face-crop-eyes", "preprocessor", preferred_package="bob.bio.face"
cropper = bob.bio.face.preprocessor.FaceCrop(
cropped_image_size=(CROPPED_IMAGE_HEIGHT, CROPPED_IMAGE_WIDTH),
cropped_positions={"leye": LEFT_EYE_POS, "reye": RIGHT_EYE_POS},
)
eyes = cropper.cropped_positions
# generate aligned graph extractor
......@@ -131,10 +146,6 @@ def test_graphs():
def test_lgbphs():
data = _data()
lgbphs = bob.bio.base.load_resource(
"lgbphs", "extractor", preferred_package="bob.bio.face"
)
assert isinstance(lgbphs, bob.bio.face.extractor.LGBPHS)
# in this test, we use a smaller setup of the LGBPHS features
lgbphs = bob.bio.face.extractor.LGBPHS(
......@@ -171,22 +182,3 @@ def test_lgbphs():
"bob.bio.face.test", "data/lgbphs_with_phase.hdf5"
)
_compare(feature, reference)
"""
def test05_sift_key_points(self):
# check if VLSIFT is available
import bob.ip.base
if not hasattr(bob.ip.base, "VLSIFT"):
raise SkipTest("VLSIFT is not part of bob.ip.base; maybe SIFT headers aren't installed in your system?")
# we need the preprocessor tool to actually read the data
preprocessor = facereclib.preprocessing.Keypoints()
data = preprocessor.read_data(self.input_dir('key_points.hdf5'))
# now, we extract features from it
extractor = self.config('sift')
feature = self.execute(extractor, data, 'sift.hdf5', epsilon=1e-4)
self.assertEqual(len(feature.shape), 1)
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment