#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 


import numpy
import tensorflow as tf

from bob.learn.tensorflow.utils.session import Session


class Embedding(object):
    """
    Embedding abstraction: evaluates a TensorFlow graph on a batch of data.

    **Parameters**

      inputs: Input placeholder; used as the ``feed_dict`` key for the data.

      graph: Embedding graph; the tensor fetched by ``session.run``.

      normalizer: Optional callable applied to every sample (each entry along
        the first axis of ``data``) before the data is fed to the graph.
        ``None`` disables normalization.
    """

    def __init__(self, inputs, graph, normalizer=None):
        self.inputs = inputs
        self.graph = graph
        self.normalizer = normalizer

    def _normalize(self, data):
        """Return ``data`` with ``self.normalizer`` applied to every sample.

        The caller's array is never modified. The normalized samples are
        collected into a new array whose dtype is inferred from the
        normalizer's outputs, so integer inputs scaled to floats are not
        truncated back to the input dtype. With no normalizer, or an empty
        batch, ``data`` is returned unchanged.
        """
        if self.normalizer is None or len(data) == 0:
            return data
        return numpy.stack([self.normalizer(sample) for sample in data])

    def __call__(self, data):
        """Compute the embedding of ``data``.

        ``data`` is presumably a numpy batch with samples along the first
        axis; confirm against callers.

        Returns the value of ``self.graph`` evaluated with ``data`` (after
        optional normalization) fed into ``self.inputs``. The evaluation uses
        the session held by ``Session.instance()``.
        """
        session = Session.instance().session

        # Normalize into a fresh array instead of assigning back into
        # ``data``. In-place assignment mutated the caller's array and cast
        # the normalized values to its original dtype.
        data = self._normalize(data)

        feed_dict = {self.inputs: data}
        return session.run([self.graph], feed_dict=feed_dict)[0]