Skip to content
Snippets Groups Projects
Commit 1c49a056 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Harmonized opencv interface

parent f453965e
Branches
Tags
1 merge request!112Feature extractors
...@@ -13,7 +13,7 @@ from bob.extension.download import get_file ...@@ -13,7 +13,7 @@ from bob.extension.download import get_file
class OpenCVTransformer(TransformerMixin, BaseEstimator): class OpenCVTransformer(TransformerMixin, BaseEstimator):
""" """
Base Transformer using the OpenCV interface. Base Transformer using the OpenCV DNN interface (https://docs.opencv.org/master/d2/d58/tutorial_table_of_content_dnn.html).
.. note:: .. note::
...@@ -28,13 +28,24 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator): ...@@ -28,13 +28,24 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator):
config: config:
Path containing some configuration file (e.g. .json, .prototxt) Path containing some configuration file (e.g. .json, .prototxt)
preprocessor:
A function that will transform the data right before forward. The default transformation is `X/255`
""" """
def __init__(self, checkpoint_path=None, config=None, **kwargs): def __init__(
self,
checkpoint_path=None,
config=None,
preprocessor=lambda x: x / 255,
**kwargs,
):
super().__init__(**kwargs) super().__init__(**kwargs)
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
self.config = config self.config = config
self.model = None self.model = None
self.preprocessor = preprocessor
def _load_model(self): def _load_model(self):
import cv2 import cv2
...@@ -63,14 +74,11 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator): ...@@ -63,14 +74,11 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator):
if self.model is None: if self.model is None:
self._load_model() self._load_model()
import ipdb X = check_array(X, allow_nd=True)
ipdb.set_trace()
img = np.array(X) X = self.preprocessor(X)
img = img / 255
self.model.setInput(img) self.model.setInput(X)
return self.model.forward() return self.model.forward()
...@@ -108,10 +116,30 @@ class VGG16_Oxford(OpenCVTransformer): ...@@ -108,10 +116,30 @@ class VGG16_Oxford(OpenCVTransformer):
config = os.path.join(path, "vgg_face_caffe", "VGG_FACE_deploy.prototxt") config = os.path.join(path, "vgg_face_caffe", "VGG_FACE_deploy.prototxt")
checkpoint_path = os.path.join(path, "vgg_face_caffe", "VGG_FACE.caffemodel") checkpoint_path = os.path.join(path, "vgg_face_caffe", "VGG_FACE.caffemodel")
super(VGG16_Oxford, self).__init__(checkpoint_path, config) caffe_average_img = [129.1863, 104.7624, 93.5940]
def preprocessor(X):
"""
Normalize using data from caffe
Caffe has the shape `C x H x W` and the chanel is BGR and
"""
# To BGR
X = X[:, ::-1, :, :].astype("float32")
# Subtracting
X[:, :, :, 0] -= caffe_average_img[0]
X[:, :, :, 1] -= caffe_average_img[1]
X[:, :, :, 2] -= caffe_average_img[2]
return X
super(VGG16_Oxford, self).__init__(checkpoint_path, config, preprocessor)
def _load_model(self): def _load_model(self):
import cv2 import cv2
net = cv2.dnn.readNet(self.checkpoint_path, self.config) net = cv2.dnn.readNet(self.checkpoint_path, self.config)
self.model = net self.model = net
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment