Skip to content
Snippets Groups Projects
Commit 1c49a056 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Harmonized opencv interface

parent f453965e
No related branches found
No related tags found
1 merge request!112Feature extractors
...@@ -13,7 +13,7 @@ from bob.extension.download import get_file ...@@ -13,7 +13,7 @@ from bob.extension.download import get_file
class OpenCVTransformer(TransformerMixin, BaseEstimator): class OpenCVTransformer(TransformerMixin, BaseEstimator):
""" """
Base Transformer using the OpenCV interface. Base Transformer using the OpenCV DNN interface (https://docs.opencv.org/master/d2/d58/tutorial_table_of_content_dnn.html).
.. note:: .. note::
...@@ -28,13 +28,24 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator): ...@@ -28,13 +28,24 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator):
config: config:
Path containing some configuration file (e.g. .json, .prototxt) Path containing some configuration file (e.g. .json, .prototxt)
preprocessor:
A function that will transform the data right before forward. The default transformation is `X/255`
""" """
def __init__(self, checkpoint_path=None, config=None, **kwargs): def __init__(
self,
checkpoint_path=None,
config=None,
preprocessor=lambda x: x / 255,
**kwargs,
):
super().__init__(**kwargs) super().__init__(**kwargs)
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
self.config = config self.config = config
self.model = None self.model = None
self.preprocessor = preprocessor
def _load_model(self): def _load_model(self):
import cv2 import cv2
...@@ -63,14 +74,11 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator): ...@@ -63,14 +74,11 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator):
if self.model is None: if self.model is None:
self._load_model() self._load_model()
import ipdb X = check_array(X, allow_nd=True)
ipdb.set_trace()
img = np.array(X) X = self.preprocessor(X)
img = img / 255
self.model.setInput(img) self.model.setInput(X)
return self.model.forward() return self.model.forward()
...@@ -108,10 +116,30 @@ class VGG16_Oxford(OpenCVTransformer): ...@@ -108,10 +116,30 @@ class VGG16_Oxford(OpenCVTransformer):
config = os.path.join(path, "vgg_face_caffe", "VGG_FACE_deploy.prototxt") config = os.path.join(path, "vgg_face_caffe", "VGG_FACE_deploy.prototxt")
checkpoint_path = os.path.join(path, "vgg_face_caffe", "VGG_FACE.caffemodel") checkpoint_path = os.path.join(path, "vgg_face_caffe", "VGG_FACE.caffemodel")
super(VGG16_Oxford, self).__init__(checkpoint_path, config) caffe_average_img = [129.1863, 104.7624, 93.5940]
def preprocessor(X):
"""
Normalize using data from caffe
Caffe has the shape `C x H x W` and the chanel is BGR and
"""
# To BGR
X = X[:, ::-1, :, :].astype("float32")
# Subtracting
X[:, :, :, 0] -= caffe_average_img[0]
X[:, :, :, 1] -= caffe_average_img[1]
X[:, :, :, 2] -= caffe_average_img[2]
return X
super(VGG16_Oxford, self).__init__(checkpoint_path, config, preprocessor)
def _load_model(self): def _load_model(self):
import cv2 import cv2
net = cv2.dnn.readNet(self.checkpoint_path, self.config) net = cv2.dnn.readNet(self.checkpoint_path, self.config)
self.model = net self.model = net
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment