# embedding_validation.py
import tensorflow as tf
from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings


class EmbeddingValidation(tf.keras.Model):
    """
    Use this model if the validation step should validate the accuracy with respect to embeddings.

    In this model, the `test_step` runs the function
    `bob.learn.tensorflow.metrics.embedding_accuracy.accuracy_from_embeddings`
    on the embeddings produced by the network.

    The forward pass (``self(x, training=...)``) must return a pair
    ``(logits, prelogits)``:

    * ``logits`` feed the training loss and the compiled metrics in
      ``train_step``;
    * ``prelogits`` (the embeddings) are scored in ``test_step``.
    """

    def compile(self, **kwargs):
        """
        Compile the model and create the custom running-mean metrics.

        All keyword arguments are forwarded unchanged to
        ``tf.keras.Model.compile``. Two extra metrics are then created:

        * ``train_loss``, reported as ``"loss"`` by ``train_step``, averages
          the loss values computed during training;
        * ``validation_acc``, reported as ``"accuracy"`` by ``test_step``,
          averages the embedding accuracy computed during validation.
        """
        super().compile(**kwargs)
        # The two metrics need distinct names. Otherwise the training loss
        # would be logged under the accuracy key and clobber any compiled
        # "accuracy" metric.
        self.train_loss = tf.keras.metrics.Mean(name="loss")
        self.validation_acc = tf.keras.metrics.Mean(name="accuracy")

    def train_step(self, data):
        """
        Run one optimisation step on a batch.

        Parameters
        ----------
        data : tuple
            ``(X, y)`` pair: network inputs and their targets.

        Returns
        -------
        dict
            Current results of the compiled metrics, plus the running mean of
            the training loss under ``"loss"``.
        """
        X, y = data
        with tf.GradientTape() as tape:
            # Only the logits (first output) contribute to the loss.
            logits, _ = self(X, training=True)
            # ``self.loss`` is called directly. This presumably requires the
            # ``loss`` passed to ``compile`` to be a callable, not a string
            # identifier. TODO confirm with callers.
            loss = self.loss(y, logits)

        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)

        self.compiled_metrics.update_state(y, logits)
        self.train_loss(loss)

        # Build the log dict explicitly. ``self.metrics`` can also contain the
        # metrics created in ``compile`` (Keras tracks Metric attributes set on
        # the model). Iterating it would list ``train_loss`` twice and leak
        # ``validation_acc`` into the training logs.
        results = {m.name: m.result() for m in self.compiled_metrics.metrics}
        results[self.train_loss.name] = self.train_loss.result()
        return results

    def test_step(self, data):
        """
        Run one validation step on a batch.

        Parameters
        ----------
        data : tuple
            ``(images, labels)`` pair for the validation batch.

        Returns
        -------
        dict
            Running mean of the embedding accuracy under ``"accuracy"``.
        """
        images, labels = data
        # Validation only needs the embeddings (second output).
        _, prelogits = self(images, training=False)
        self.validation_acc(accuracy_from_embeddings(labels, prelogits))
        return {self.validation_acc.name: self.validation_acc.result()}