Commit e90a4429 authored by Tiago Pereira's avatar Tiago Pereira

Added normalized in the Embedding

parent 7b03d324
Pipeline #11217 passed with stages
in 12 minutes and 53 seconds
......@@ -6,6 +6,7 @@
import tensorflow as tf
from bob.learn.tensorflow.utils.session import Session
from bob.learn.tensorflow.datashuffler import Linear
class Embedding(object):
......@@ -19,12 +20,18 @@ class Embedding(object):
graph: Embedding graph
"""
def __init__(self, input, graph):
def __init__(self, input, graph, normalizer=Linear()):
self.input = input
self.graph = graph
self.normalizer = normalizer
def __call__(self, data):
session = Session.instance().session
if self.normalizer is not None:
for i in range(data.shape[0]):
data[i] = self.normalizer(data[i])
feed_dict = {self.input: data}
return session.run([self.graph], feed_dict=feed_dict)[0]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment