#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:33 CEST

"""
Neural network error-rate analyzers
"""
import numpy
from tensorflow.core.framework import summary_pb2


class SoftmaxAnalizer(object):
    """
    Validation-accuracy analyzer for a softmax classification network.

    Draws one batch from a data shuffler, runs the network on it, and returns the
    percentage of correctly classified samples as a TensorFlow ``Summary`` protobuf.

    The data shuffler, trainer and session given on the first call are stored and
    reused for every later call; arguments passed on later calls are ignored.
    """

    def __init__(self):
        """
        Create an analyzer with no bound state; everything is bound on the first call.
        """
        self.data_shuffler = None
        self.trainer = None
        self.session = None

        # Input placeholder and output tensor of the validation graph. They are built
        # once and reused so that repeated validation calls do not keep adding new
        # placeholders/ops to the TensorFlow graph.
        self._feature_batch = None
        self._graph = None

    def __call__(self, data_shuffler, trainer, session):
        """
        Compute the classification accuracy on one batch from the data shuffler.

        ** Parameters **

          data_shuffler: Source of validation data. Must provide
            ``get_placeholders(name=...)`` returning ``(features, labels)`` placeholders
            and ``get_batch()`` returning ``(data, labels)``. Only used on the first call.
          trainer: Object whose ``architecture.compute_graph(placeholder)`` builds the
            network output. Only used on the first call.
          session: Session used to evaluate the network. Only used on the first call.

        ** Returns **

          A ``summary_pb2.Summary`` with a single value tagged ``accuracy_validation``
          holding the accuracy as a percentage (0-100).
        """
        if self.data_shuffler is None:
            self.data_shuffler = data_shuffler
            self.trainer = trainer
            self.session = session

        if self._graph is None:
            # Build the validation graph once; later calls reuse the same tensors, which
            # still read the current variable values when run in the session.
            self._feature_batch, _ = self.data_shuffler.get_placeholders(name="validation_accuracy")
            self._graph = self.trainer.architecture.compute_graph(self._feature_batch)

        data, labels = self.data_shuffler.get_batch()
        logits = self.session.run(self._graph, feed_dict={self._feature_batch: data[:]})
        predictions = numpy.argmax(logits, 1)
        # NOTE(review): assumes `labels` is 1-D with one class index per sample; a column
        # vector would broadcast against `predictions` and corrupt the count — confirm.
        accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]

        summaries = [summary_pb2.Summary.Value(tag="accuracy_validation", simple_value=float(accuracy))]
        return summary_pb2.Summary(value=summaries)


class SoftmaxSiameseAnalizer(object):
    """
    Validation-accuracy analyzer for a network fed with pairs of samples.

    Draws one batch of pairs from a data shuffler, runs the network on the left
    element of each pair, and returns the percentage of predictions matching the
    batch labels as a TensorFlow ``Summary`` protobuf.

    The data shuffler, machine and session given on the first call are stored and
    reused for every later call; arguments passed on later calls are ignored.
    """

    def __init__(self):
        """
        Create an analyzer with no bound state; everything is bound on the first call.
        """
        self.data_shuffler = None
        self.machine = None
        self.session = None

        # Left-input placeholder and output tensor of the validation graph, built once
        # so repeated calls do not keep adding new ops to the TensorFlow graph.
        self._feature_left_batch = None
        self._graph = None

    def __call__(self, data_shuffler, machine, session):
        """
        Compute the accuracy on one batch of pairs from the data shuffler.

        ** Parameters **

          data_shuffler: Source of validation pairs. Must provide
            ``get_placeholders_pair(name=...)`` returning ``(left, right, labels)``
            placeholders and ``get_batch()`` returning ``(left, right, labels)``.
            Only used on the first call.
          machine: Object whose ``compute_graph(placeholder)`` builds the network
            output. Only used on the first call.
          session: Session used to evaluate the network. Only used on the first call.

        ** Returns **

          A ``summary_pb2.Summary`` with a single value tagged ``accuracy_validation``
          holding the accuracy as a percentage (0-100).
        """
        if self.data_shuffler is None:
            self.data_shuffler = data_shuffler
            self.machine = machine
            self.session = session

        if self._graph is None:
            feature_left_batch, _, _ = self.data_shuffler.get_placeholders_pair(name="validation_accuracy")
            self._feature_left_batch = feature_left_batch
            self._graph = self.machine.compute_graph(feature_left_batch)

        batch_left, _, labels = self.data_shuffler.get_batch()
        logits = self.session.run(self._graph, feed_dict={self._feature_left_batch: batch_left[:]})
        predictions = numpy.argmax(logits, 1)
        # NOTE(review): only the left branch is scored and compared with `labels`.
        # Confirm that the pair shuffler's labels are per-sample class indices of the
        # left element; if they are same/different flags, this accuracy is not meaningful.
        accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]

        summaries = [summary_pb2.Summary.Value(tag="accuracy_validation", simple_value=float(accuracy))]
        return summary_pb2.Summary(value=summaries)