Commit e90a4429 authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Added normalized in the Embedding

parent 7b03d324
Pipeline #11217 passed with stages
in 12 minutes and 53 seconds
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import tensorflow as tf import tensorflow as tf
from bob.learn.tensorflow.utils.session import Session from bob.learn.tensorflow.utils.session import Session
from bob.learn.tensorflow.datashuffler import Linear
class Embedding(object): class Embedding(object):
...@@ -19,12 +20,18 @@ class Embedding(object): ...@@ -19,12 +20,18 @@ class Embedding(object):
graph: Embedding graph graph: Embedding graph
""" """
def __init__(self, input, graph): def __init__(self, input, graph, normalizer=Linear()):
self.input = input self.input = input
self.graph = graph self.graph = graph
self.normalizer = normalizer
def __call__(self, data): def __call__(self, data):
session = Session.instance().session session = Session.instance().session
if self.normalizer is not None:
for i in range(data.shape[0]):
data[i] = self.normalizer(data[i])
feed_dict = {self.input: data} feed_dict = {self.input: data}
return session.run([self.graph], feed_dict=feed_dict)[0] return session.run([self.graph], feed_dict=feed_dict)[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment