SoftmaxAnalizer.py 1.51 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:33 CEST

"""
Neural network error-rate / accuracy analyzers.
"""
import numpy
from tensorflow.core.framework import summary_pb2


13
class SoftmaxAnalizer(object):
    """
    Analyzer that measures the classification accuracy of a softmax network
    on a validation batch and reports it as a TensorFlow summary.
    """

    # Validation input placeholder and network output. They are built lazily on
    # the first call and then reused, so repeated validation calls do not keep
    # adding new ops to the TensorFlow graph. Class-level defaults keep this
    # working even for subclasses that skip ``__init__``.
    _feature_batch = None
    _graph = None

    def __init__(self):
        """
        Create an unbound softmax analyzer.

        ``data_shuffler``, ``machine`` and ``session`` start as ``None``. They
        are bound on the first call to the instance (see ``__call__``).
        """
        self.data_shuffler = None
        self.machine = None
        self.session = None

    def __call__(self, data_shuffler, machine, session):
        """
        Compute the validation accuracy and wrap it in a TensorFlow summary.

        On the first call the given ``data_shuffler``, ``machine`` and
        ``session`` are stored on the instance. On later calls these arguments
        are ignored and the stored objects are used instead.

        ** Parameters **

          data_shuffler: Object providing ``get_placeholders(name=...)``
            (returning a ``(features, labels)`` placeholder pair) and
            ``get_batch()`` (returning a ``(data, labels)`` pair).
          machine: Object providing ``compute_graph(placeholder)``, which builds
            the network output for the given input placeholder.
          session: Session whose ``run(fetches, feed_dict=...)`` evaluates the
            network (presumably a TF1 ``tf.Session`` -- confirm with callers).

        ** Returns **

          A ``summary_pb2.Summary`` holding a single value tagged
          ``accuracy_validation``, whose value is the accuracy in percent.
        """
        if self.data_shuffler is None:
            self.data_shuffler = data_shuffler
            self.machine = machine
            self.session = session

        # Build the validation sub-graph only once. Rebuilding it on every call
        # would grow the TensorFlow graph each time validation runs.
        if self._graph is None:
            self._feature_batch, _ = self.data_shuffler.get_placeholders(name="validation_accuracy")
            self._graph = self.machine.compute_graph(self._feature_batch)

        data, labels = self.data_shuffler.get_batch()

        # The network output is reduced with argmax over axis 1, i.e. one score
        # per class in each row. Assumes ``labels`` is a 1-D array of integer
        # class ids aligned with the rows of ``data`` -- TODO confirm with the
        # data shuffler.
        scores = self.session.run(self._graph, feed_dict={self._feature_batch: data})
        predictions = numpy.argmax(scores, 1)
        accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]

        value = summary_pb2.Summary.Value(tag="accuracy_validation", simple_value=float(accuracy))
        return summary_pb2.Summary(value=[value])