Commit 16717689 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixing batching issue

parent 77e0ba63
......@@ -9,6 +9,7 @@ from sklearn.base import TransformerMixin, BaseEstimator
import numpy as np
import logging
from sklearn.utils import check_array
from bob.pipelines.sample import SampleBatch
logger = logging.getLogger(__name__)
......@@ -38,13 +39,13 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
self.architecture_fn = architecture_fn
self.loaded = False
def transform(self, data):
def transform(self, X):
"""
Forward the data with the loaded neural network
Parameters
----------
image : numpy.ndarray
X : numpy.ndarray
Input Data
Returns
......@@ -54,30 +55,36 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
"""
data = check_array(data, allow_nd=True)
def _transform(data):
data = check_array(data, allow_nd=True)
# THE INPUT SHAPE FOR THESE MODELS
# ARE `N x C x H x W`
# THE INPUT SHAPE FOR THESE MODELS
# ARE `N x C x H x W`
# If ndim==3 we add another axis
if data.ndim == 3:
data = data[None, ...]
# If ndim==3 we add another axis
if data.ndim == 3:
data = data[None, ...]
# Making sure it's channels last and has three channels
if data.ndim == 4:
# Just swiping the second dimension if bob format NxCxHxH
if data.shape[1] == 3:
data = np.moveaxis(data, 1, -1)
# Making sure it's channels last and has three channels
if data.ndim == 4:
# Just swiping the second dimension if bob format NxCxHxH
if data.shape[1] == 3:
data = np.moveaxis(data, 1, -1)
if data.shape != self.input_shape:
raise ValueError(
f"Image shape {data.shape} not supported. Expected {self.input_shape}"
)
if data.shape != self.input_shape:
raise ValueError(
f"Image shape {data.shape} not supported. Expected {self.input_shape}"
)
if not self.loaded:
self.load_model()
if not self.loaded:
self.load_model()
return self.session.run(self.embedding, feed_dict={self.input_tensor: data},)
return self.session.run(self.embedding, feed_dict={self.input_tensor: data},)
if isinstance(X, SampleBatch):
return [_transform(x) for x in X]
else:
return _transform(X)
def load_model(self):
import tensorflow as tf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment