Skip to content
Snippets Groups Projects
Commit e90a4429 authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Added normalized in the Embedding

parent 7b03d324
Branches
Tags
No related merge requests found
Pipeline #
......@@ -6,6 +6,7 @@
import tensorflow as tf
from bob.learn.tensorflow.utils.session import Session
from bob.learn.tensorflow.datashuffler import Linear
class Embedding(object):
......@@ -19,12 +20,18 @@ class Embedding(object):
graph: Embedding graph
"""
def __init__(self, input, graph):
def __init__(self, input, graph, normalizer=Linear()):
self.input = input
self.graph = graph
self.normalizer = normalizer
def __call__(self, data):
session = Session.instance().session
if self.normalizer is not None:
for i in range(data.shape[0]):
data[i] = self.normalizer(data[i])
feed_dict = {self.input: data}
return session.run([self.graph], feed_dict=feed_dict)[0]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment