Commit 1c49a056 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Harmonized opencv interface

parent f453965e
......@@ -13,7 +13,7 @@ from bob.extension.download import get_file
class OpenCVTransformer(TransformerMixin, BaseEstimator):
"""
Base Transformer using the OpenCV interface.
Base Transformer using the OpenCV DNN interface (https://docs.opencv.org/master/d2/d58/tutorial_table_of_content_dnn.html).
.. note::
......@@ -28,13 +28,24 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator):
config:
Path containing some configuration file (e.g. .json, .prototxt)
preprocessor:
A function that will transform the data right before forward. The default transformation is `X/255`
"""
def __init__(self, checkpoint_path=None, config=None, **kwargs):
def __init__(
self,
checkpoint_path=None,
config=None,
preprocessor=lambda x: x / 255,
**kwargs,
):
super().__init__(**kwargs)
self.checkpoint_path = checkpoint_path
self.config = config
self.model = None
self.preprocessor = preprocessor
def _load_model(self):
import cv2
......@@ -63,14 +74,11 @@ class OpenCVTransformer(TransformerMixin, BaseEstimator):
if self.model is None:
self._load_model()
import ipdb
ipdb.set_trace()
X = check_array(X, allow_nd=True)
img = np.array(X)
img = img / 255
X = self.preprocessor(X)
self.model.setInput(img)
self.model.setInput(X)
return self.model.forward()
......@@ -108,10 +116,30 @@ class VGG16_Oxford(OpenCVTransformer):
config = os.path.join(path, "vgg_face_caffe", "VGG_FACE_deploy.prototxt")
checkpoint_path = os.path.join(path, "vgg_face_caffe", "VGG_FACE.caffemodel")
super(VGG16_Oxford, self).__init__(checkpoint_path, config)
caffe_average_img = [129.1863, 104.7624, 93.5940]
def preprocessor(X):
"""
Normalize using data from caffe
Caffe has the shape `C x H x W` and the chanel is BGR and
"""
# To BGR
X = X[:, ::-1, :, :].astype("float32")
# Subtracting
X[:, :, :, 0] -= caffe_average_img[0]
X[:, :, :, 1] -= caffe_average_img[1]
X[:, :, :, 2] -= caffe_average_img[2]
return X
super(VGG16_Oxford, self).__init__(checkpoint_path, config, preprocessor)
def _load_model(self):
import cv2
net = cv2.dnn.readNet(self.checkpoint_path, self.config)
self.model = net
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment