Skip to content
Snippets Groups Projects
Commit 16717689 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixing batching issue

parent 77e0ba63
No related branches found
Tags v3.0.0
1 merge request!64Dask pipelines
......@@ -9,6 +9,7 @@ from sklearn.base import TransformerMixin, BaseEstimator
import numpy as np
import logging
from sklearn.utils import check_array
from bob.pipelines.sample import SampleBatch
logger = logging.getLogger(__name__)
......@@ -38,13 +39,13 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
self.architecture_fn = architecture_fn
self.loaded = False
def transform(self, data):
def transform(self, X):
"""
Forward the data with the loaded neural network
Parameters
----------
image : numpy.ndarray
X : numpy.ndarray
Input Data
Returns
......@@ -54,30 +55,36 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
"""
data = check_array(data, allow_nd=True)
def _transform(data):
data = check_array(data, allow_nd=True)
# THE INPUT SHAPE FOR THESE MODELS
# ARE `N x C x H x W`
# THE INPUT SHAPE FOR THESE MODELS
# ARE `N x C x H x W`
# If ndim==3 we add another axis
if data.ndim == 3:
data = data[None, ...]
# If ndim==3 we add another axis
if data.ndim == 3:
data = data[None, ...]
# Making sure it's channels last and has three channels
if data.ndim == 4:
# Just swiping the second dimension if bob format NxCxHxH
if data.shape[1] == 3:
data = np.moveaxis(data, 1, -1)
# Making sure it's channels last and has three channels
if data.ndim == 4:
# Just swiping the second dimension if bob format NxCxHxH
if data.shape[1] == 3:
data = np.moveaxis(data, 1, -1)
if data.shape != self.input_shape:
raise ValueError(
f"Image shape {data.shape} not supported. Expected {self.input_shape}"
)
if data.shape != self.input_shape:
raise ValueError(
f"Image shape {data.shape} not supported. Expected {self.input_shape}"
)
if not self.loaded:
self.load_model()
if not self.loaded:
self.load_model()
return self.session.run(self.embedding, feed_dict={self.input_tensor: data},)
return self.session.run(self.embedding, feed_dict={self.input_tensor: data},)
if isinstance(X, SampleBatch):
return [_transform(x) for x in X]
else:
return _transform(X)
def load_model(self):
import tensorflow as tf
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment