From cc8a142a6da0c8b7c1cd4135b58c87d7a36a6bf6 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Thu, 2 May 2019 14:19:01 +0200
Subject: [PATCH] EPSC estimators and losses

---
 bob/learn/tensorflow/estimators/EPSC.py | 489 ++++++++++++++++++++++++
 bob/learn/tensorflow/loss/epsc.py       | 178 +++++++++
 2 files changed, 667 insertions(+)
 create mode 100644 bob/learn/tensorflow/estimators/EPSC.py
 create mode 100644 bob/learn/tensorflow/loss/epsc.py

diff --git a/bob/learn/tensorflow/estimators/EPSC.py b/bob/learn/tensorflow/estimators/EPSC.py
new file mode 100644
index 00000000..8665830b
--- /dev/null
+++ b/bob/learn/tensorflow/estimators/EPSC.py
@@ -0,0 +1,489 @@
+# vim: set fileencoding=utf-8 :
+# @author: Amir Mohammadi <amir.mohammadi@idiap.ch>
+
+from . import check_features, get_trainable_variables
+from .Logits import moving_average_scaffold
+from ..network.utils import append_logits
+from ..utils import predict_using_tensors
+from ..loss.epsc import epsc_metric, siamese_loss
+from tensorflow.python.estimator import estimator
+import tensorflow as tf
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class EPSCBase:
+    """A base class for EPSC based estimators"""
+
+    def _get_loss(self, bio_logits, pad_logits, bio_labels, pad_labels, mode):
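+        """Combine the main EPSC loss with the optional regularization and
+        VAT losses into the total loss."""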
+        main_loss = self.loss_op(
+            bio_logits=bio_logits,
+            pad_logits=pad_logits,
+            bio_labels=bio_labels,
+            pad_labels=pad_labels,
+        )
+        total_loss = main_loss
+
+        if self.add_regularization_losses:
+
+            regularization_losses = tf.get_collection(
+                tf.GraphKeys.REGULARIZATION_LOSSES
+            )
+            regularization_losses = [
+                tf.cast(l, main_loss.dtype) for l in regularization_losses
+            ]
+
+            regularization_losses = tf.add_n(
+                regularization_losses, name="regularization_losses"
+            )
+            tf.summary.scalar("regularization_losses", regularization_losses)
+
+            total_loss = tf.add_n([main_loss, regularization_losses], name="total_loss")
+
+        if self.vat_loss is not None:
+            vat_loss = self.vat_loss(
+                self.end_points["features"],
+                self.end_points["Logits/PAD"],
+                self.pad_architecture,
+                mode,
+            )
+            # stack the VAT loss on top of the (possibly regularized) total
+            total_loss = tf.add_n([total_loss, vat_loss], name="total_loss")
+
+        return total_loss
+
+
+class EPSCLogits(EPSCBase, estimator.Estimator):
+    """An logits estimator for epsc problems"""
+
+    def __init__(
+        self,
+        architecture,
+        optimizer,
+        loss_op,
+        n_classes,
+        config=None,
+        embedding_validation=False,
+        model_dir="",
+        validation_batch_size=None,
+        extra_checkpoint=None,
+        apply_moving_averages=True,
+        add_histograms="train",
+        add_regularization_losses=True,
+        vat_loss=None,
+        optimize_loss=tf.contrib.layers.optimize_loss,
+        optimize_loss_learning_rate=None,
+    ):
+
+        self.architecture = architecture
+        self.n_classes = n_classes
+        self.loss_op = loss_op
+        self.loss = None
+        self.embedding_validation = embedding_validation
+        self.extra_checkpoint = extra_checkpoint
+        self.add_regularization_losses = add_regularization_losses
+        self.apply_moving_averages = apply_moving_averages
+        self.vat_loss = vat_loss
+        self.optimize_loss = optimize_loss
+        self.optimize_loss_learning_rate = optimize_loss_learning_rate
+
+        if apply_moving_averages and isinstance(optimizer, tf.train.Optimizer):
+            logger.info(
+                "Encapsulating the optimizer with " "the MovingAverageOptimizer"
+            )
+            optimizer = tf.contrib.opt.MovingAverageOptimizer(optimizer)
+
+        self.optimizer = optimizer
+
+        def _model_fn(features, labels, mode):
+
+            check_features(features)
+            data = features["data"]
+            key = features["key"]
+
+            # Checking if we have some variables/scope that we may want to shut
+            # down
+            trainable_variables = get_trainable_variables(
+                self.extra_checkpoint, mode=mode
+            )
+            prelogits, end_points = self.architecture(
+                data, mode=mode, trainable_variables=trainable_variables
+            )
+
+            name = "Logits/Bio"
+            bio_logits = append_logits(
+                prelogits, n_classes, trainable_variables=trainable_variables, name=name
+            )
+            end_points[name] = bio_logits
+
+            name = "Logits/PAD"
+            pad_logits = append_logits(
+                prelogits, 2, trainable_variables=trainable_variables, name=name
+            )
+            end_points[name] = pad_logits
+
+            self.end_points = end_points
+
+            # for vat_loss
+            self.end_points["features"] = data
+
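+            # closure that recomputes the PAD logits for arbitrary inputs
+            # while reusing the network variables; handed to the optional
+            # vat_loss below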
+            def pad_architecture(features, mode, reuse):
+                prelogits, end_points = self.architecture(
+                    features,
+                    mode=mode,
+                    trainable_variables=trainable_variables,
+                    reuse=reuse,
+                )
+                pad_logits = append_logits(
+                    prelogits,
+                    2,
+                    reuse=reuse,
+                    trainable_variables=trainable_variables,
+                    name="Logits/PAD",
+                )
+                return pad_logits, end_points
+
+            self.pad_architecture = pad_architecture
+
+            if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
+
+                # Compute the embeddings
+                embeddings = tf.nn.l2_normalize(prelogits, 1)
+                predictions = {"embeddings": embeddings}
+            else:
+                predictions = {
+                    # Generate predictions (for PREDICT and EVAL mode)
+                    "bio_classes": tf.argmax(input=bio_logits, axis=1),
+                    # Add `softmax_tensor` to the graph. It is used for PREDICT
+                    # and by the `logging_hook`.
+                    "bio_probabilities": tf.nn.softmax(
+                        bio_logits, name="bio_softmax_tensor"
+                    ),
+                }
+
+            predictions.update(
+                {
+                    "pad_classes": tf.argmax(input=pad_logits, axis=1),
+                    "pad_probabilities": tf.nn.softmax(
+                        pad_logits, name="pad_softmax_tensor"
+                    ),
+                    "key": key,
+                }
+            )
+
+            # add predictions to end_points
+            self.end_points.update(predictions)
+
+            if mode == tf.estimator.ModeKeys.PREDICT:
+                return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+            bio_labels = labels["bio"]
+            pad_labels = labels["pad"]
+
+            if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
+                bio_predictions_op = predict_using_tensors(
+                    predictions["embeddings"], bio_labels, num=validation_batch_size
+                )
+            else:
+                bio_predictions_op = predictions["bio_classes"]
+
+            pad_predictions_op = predictions["pad_classes"]
+
+            metrics = {
+                "bio_accuracy": tf.metrics.accuracy(
+                    labels=bio_labels, predictions=bio_predictions_op
+                ),
+                "pad_accuracy": tf.metrics.accuracy(
+                    labels=pad_labels, predictions=pad_predictions_op
+                ),
+            }
+
+            if mode == tf.estimator.ModeKeys.EVAL:
+                self.loss = self._get_loss(
+                    bio_logits, pad_logits, bio_labels, pad_labels, mode=mode
+                )
+                return tf.estimator.EstimatorSpec(
+                    mode=mode,
+                    predictions=predictions,
+                    loss=self.loss,
+                    train_op=None,
+                    eval_metric_ops=metrics,
+                )
+
+            # restore the model from an extra_checkpoint
+            if self.extra_checkpoint is not None:
+                if "Logits/" not in self.extra_checkpoint["scopes"]:
+                    logger.warning(
+                        '"Logits/" (which are automatically added by this '
+                        "Logits class are not in the scopes of "
+                        "extra_checkpoint). Did you mean to restore the "
+                        "Logits variables as well?"
+                    )
+
+                logger.info(
+                    "Restoring model from %s in scopes %s",
+                    self.extra_checkpoint["checkpoint_path"],
+                    self.extra_checkpoint["scopes"],
+                )
+                tf.train.init_from_checkpoint(
+                    ckpt_dir_or_file=self.extra_checkpoint["checkpoint_path"],
+                    assignment_map=self.extra_checkpoint["scopes"],
+                )
+
+            # Calculate Loss
+            self.loss = self._get_loss(
+                bio_logits, pad_logits, bio_labels, pad_labels, mode=mode
+            )
+
+            # Compute the moving average of all individual losses and the total
+            # loss.
+            loss_averages = tf.train.ExponentialMovingAverage(0.9, name="avg")
+            loss_averages_op = loss_averages.apply(
+                tf.get_collection(tf.GraphKeys.LOSSES)
+            )
+            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, loss_averages_op)
+
+            with tf.name_scope("train"):
+                train_op = self.optimize_loss(
+                    loss=self.loss,
+                    global_step=tf.train.get_or_create_global_step(),
+                    optimizer=self.optimizer,
+                    learning_rate=self.optimize_loss_learning_rate,
+                )
+
+                # Get the moving average saver after optimizer.minimize is called
+                if self.apply_moving_averages:
+                    self.saver, self.scaffold = moving_average_scaffold(
+                        self.optimizer.optimizer
+                        if hasattr(self.optimizer, "optimizer")
+                        else self.optimizer,
+                        config,
+                    )
+                else:
+                    self.saver, self.scaffold = None, None
+
+                # Log accuracy and loss
+                with tf.name_scope("train_metrics"):
+                    tf.summary.scalar("bio_accuracy", metrics["bio_accuracy"][1])
+                    tf.summary.scalar("pad_accuracy", metrics["pad_accuracy"][1])
+                    for l in tf.get_collection(tf.GraphKeys.LOSSES):
+                        tf.summary.scalar(
+                            l.op.name + "_averaged", loss_averages.average(l)
+                        )
+
+            # add histograms summaries
+            if add_histograms == "all":
+                for v in tf.global_variables():
+                    tf.summary.histogram(v.name, v)
+            elif add_histograms == "train":
+                for v in tf.trainable_variables():
+                    tf.summary.histogram(v.name, v)
+
+            return tf.estimator.EstimatorSpec(
+                mode=mode,
+                predictions=predictions,
+                loss=self.loss,
+                train_op=train_op,
+                eval_metric_ops=metrics,
+                scaffold=self.scaffold,
+            )
+
+        super().__init__(model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+class EPSCSiamese(EPSCBase, estimator.Estimator):
+    """An siamese estimator for epsc problems"""
+
+    def __init__(
+        self,
+        architecture,
+        optimizer,
+        loss_op=siamese_loss,
+        config=None,
+        model_dir="",
+        validation_batch_size=None,
+        extra_checkpoint=None,
+        apply_moving_averages=True,
+        add_histograms="train",
+        add_regularization_losses=True,
+        vat_loss=None,
+        optimize_loss=tf.contrib.layers.optimize_loss,
+        optimize_loss_learning_rate=None,
+    ):
+
+        self.architecture = architecture
+        self.loss_op = loss_op
+        self.loss = None
+        self.extra_checkpoint = extra_checkpoint
+        self.add_regularization_losses = add_regularization_losses
+        self.apply_moving_averages = apply_moving_averages
+        self.vat_loss = vat_loss
+        self.optimize_loss = optimize_loss
+        self.optimize_loss_learning_rate = optimize_loss_learning_rate
+
+        if self.apply_moving_averages and isinstance(optimizer, tf.train.Optimizer):
+            logger.info(
+                "Encapsulating the optimizer with " "the MovingAverageOptimizer"
+            )
+            optimizer = tf.contrib.opt.MovingAverageOptimizer(optimizer)
+
+        self.optimizer = optimizer
+
+        def _model_fn(features, labels, mode):
+
+            if mode != tf.estimator.ModeKeys.TRAIN:
+                check_features(features)
+                data = features["data"]
+                key = features["key"]
+            else:
+                if "left" not in features or "right" not in features:
+                    raise ValueError(
+                        "The input features needs to be a dictionary "
+                        "with the keys `left` and `right`"
+                    )
+                data_right = features["right"]["data"]
+                labels_right = labels["right"]
+                data = features["left"]["data"]
+                labels = labels_left = labels["left"]
+
+            # Checking if we have some variables/scope that we may want to shut
+            # down
+            trainable_variables = get_trainable_variables(
+                self.extra_checkpoint, mode=mode
+            )
+
+            prelogits, end_points = self.architecture(
+                data, mode=mode, trainable_variables=trainable_variables
+            )
+
+            self.end_points = end_points
+
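+            # PAD score: exp(-||embedding||). The embedding-norm loss pushes
+            # bona-fide norms below the margin and attack norms above twice
+            # the margin, so bona-fide samples receive higher scores.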
+            predictions = dict(
+                bio_embeddings=tf.nn.l2_normalize(prelogits, 1),
+                pad_probabilities=tf.math.exp(-tf.norm(prelogits, ord=2, axis=-1)),
+            )
+
+            if mode == tf.estimator.ModeKeys.PREDICT:
+                predictions["key"] = key
+
+            # add predictions to end_points
+            self.end_points.update(predictions)
+
+            if mode == tf.estimator.ModeKeys.PREDICT:
+                return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+            metrics = None
+            if mode != tf.estimator.ModeKeys.TRAIN:
+                assert validation_batch_size is not None
+                bio_labels = labels["bio"]
+                pad_labels = labels["pad"]
+
+                metrics = epsc_metric(
+                    predictions["bio_embeddings"],
+                    predictions["pad_probabilities"],
+                    bio_labels,
+                    pad_labels,
+                    validation_batch_size,
+                )
+
+            if mode == tf.estimator.ModeKeys.EVAL:
+                # the Siamese loss is only defined on pairs; in EVAL mode we
+                # just report metrics, so a constant dummy loss is returned
+                self.loss = tf.constant(0.0, name="dummy_loss")
+                return tf.estimator.EstimatorSpec(
+                    mode=mode,
+                    predictions=predictions,
+                    loss=self.loss,
+                    train_op=None,
+                    eval_metric_ops=metrics,
+                )
+
+            # now that we are in TRAIN mode, build the right graph too
+            prelogits_left = prelogits
+            prelogits_right, _ = self.architecture(
+                data_right,
+                mode=mode,
+                reuse=True,
+                trainable_variables=trainable_variables,
+            )
+
+            bio_logits = {"left": prelogits_left, "right": prelogits_right}
+            pad_logits = bio_logits
+
+            bio_labels = {"left": labels_left["bio"], "right": labels_right["bio"]}
+
+            pad_labels = {"left": labels_left["pad"], "right": labels_right["pad"]}
+
+            # restore the model from an extra_checkpoint
+            if self.extra_checkpoint is not None:
+                logger.info(
+                    "Restoring model from %s in scopes %s",
+                    self.extra_checkpoint["checkpoint_path"],
+                    self.extra_checkpoint["scopes"],
+                )
+                tf.train.init_from_checkpoint(
+                    ckpt_dir_or_file=self.extra_checkpoint["checkpoint_path"],
+                    assignment_map=self.extra_checkpoint["scopes"],
+                )
+
+            global_step = tf.train.get_or_create_global_step()
+
+            # Some layer like tf.layers.batch_norm need this:
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+
+            with tf.control_dependencies(update_ops), tf.name_scope("train"):
+
+                # Calculate Loss
+                self.loss = self._get_loss(
+                    bio_logits, pad_logits, bio_labels, pad_labels, mode=mode
+                )
+
+                # Compute the moving average of all individual losses
+                # and the total loss.
+                loss_averages = tf.train.ExponentialMovingAverage(0.9, name="avg")
+                loss_averages_op = loss_averages.apply(
+                    tf.get_collection(tf.GraphKeys.LOSSES)
+                )
+
+                train_op = tf.group(
+                    self.optimize_loss(
+                        loss=self.loss,
+                        global_step=tf.train.get_or_create_global_step(),
+                        optimizer=self.optimizer,
+                        learning_rate=self.optimize_loss_learning_rate,
+                    ),
+                    loss_averages_op,
+                )
+
+                # Get the moving average saver after optimizer.minimize is called
+                if self.apply_moving_averages:
+                    self.saver, self.scaffold = moving_average_scaffold(
+                        self.optimizer.optimizer
+                        if hasattr(self.optimizer, "optimizer")
+                        else self.optimizer,
+                        config,
+                    )
+                else:
+                    self.saver, self.scaffold = None, None
+
+            # Log moving average of losses
+            with tf.name_scope("train_metrics"):
+                for l in tf.get_collection(tf.GraphKeys.LOSSES):
+                    tf.summary.scalar(l.op.name + "_averaged", loss_averages.average(l))
+
+            # add histograms summaries
+            if add_histograms == "all":
+                for v in tf.global_variables():
+                    tf.summary.histogram(v.name, v)
+            elif add_histograms == "train":
+                for v in tf.trainable_variables():
+                    tf.summary.histogram(v.name, v)
+
+            return tf.estimator.EstimatorSpec(
+                mode=mode,
+                predictions=predictions,
+                loss=self.loss,
+                train_op=train_op,
+                eval_metric_ops=metrics,
+                scaffold=self.scaffold,
+            )
+
+        super().__init__(model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/bob/learn/tensorflow/loss/epsc.py b/bob/learn/tensorflow/loss/epsc.py
new file mode 100644
index 00000000..cfadb012
--- /dev/null
+++ b/bob/learn/tensorflow/loss/epsc.py
@@ -0,0 +1,178 @@
+import tensorflow as tf
+import bob.measure
+import numpy
+from tensorflow.python.ops.metrics_impl import metric_variable
+from ..utils import norm, predict_using_tensors
+from .ContrastiveLoss import contrastive_loss
+
+
+def logits_loss(
+    bio_logits, pad_logits, bio_labels, pad_labels, bio_loss, pad_loss, alpha=0.5
+):
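+    """Convex combination of a biometric loss and a PAD loss:
+    ``(1 - alpha) * bio_loss + alpha * pad_loss``. The two partial losses
+    and the total are added to ``tf.GraphKeys.LOSSES`` and summarized."""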
+
+    with tf.name_scope("Bio_loss"):
+        bio_loss_ = bio_loss(logits=bio_logits, labels=bio_labels)
+
+    with tf.name_scope("PAD_loss"):
+        pad_loss_ = pad_loss(
+            logits=pad_logits, labels=tf.cast(pad_labels, dtype="int32")
+        )
+
+    with tf.name_scope("EPSC_loss"):
+        total_loss = (1 - alpha) * bio_loss_ + alpha * pad_loss_
+
+    tf.add_to_collection(tf.GraphKeys.LOSSES, bio_loss_)
+    tf.add_to_collection(tf.GraphKeys.LOSSES, pad_loss_)
+    tf.add_to_collection(tf.GraphKeys.LOSSES, total_loss)
+
+    tf.summary.scalar("bio_loss", bio_loss_)
+    tf.summary.scalar("pad_loss", pad_loss_)
+    tf.summary.scalar("epsc_loss", total_loss)
+
+    return total_loss
+
+
+def embedding_norm_loss(prelogits_left, prelogits_right, b, c, margin=10.0):
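+    """Hinge loss on the L2 norms of the embeddings.
+
+    ``b`` and ``c`` are 1 for bona-fide samples of the left and right
+    branches, respectively. Bona-fide norms are penalized when they exceed
+    ``margin`` and attack norms when they fall below ``2 * margin``, so the
+    two classes get separated by their embedding norm.
+    """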
+    with tf.name_scope("embedding_norm_loss"):
+        norm_left = norm(prelogits_left)
+        norm_right = norm(prelogits_right)
+
+        loss = tf.add_n(
+            [
+                tf.reduce_mean(b * (tf.maximum(norm_left - margin, 0))),
+                tf.reduce_mean((1 - b) * (tf.maximum(2 * margin - norm_left, 0))),
+                tf.reduce_mean(c * (tf.maximum(norm_right - margin, 0))),
+                tf.reduce_mean((1 - c) * (tf.maximum(2 * margin - norm_right, 0))),
+            ],
+            name="embedding_norm_loss",
+        )
+        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
+        tf.summary.scalar("embedding_norm_loss", loss)
+        # log norm of embeddings for BF and PA separately to see how their norm
+        # evolves over time
+        bf_norm = tf.concat(
+            [
+                tf.gather(norm_left, tf.where(b > 0.5)),
+                tf.gather(norm_right, tf.where(c > 0.5)),
+            ],
+            axis=0,
+        )
+        pa_norm = tf.concat(
+            [
+                tf.gather(norm_left, tf.where(b < 0.5)),
+                tf.gather(norm_right, tf.where(c < 0.5)),
+            ],
+            axis=0,
+        )
+        tf.summary.histogram("BF_embeddings_norm", bf_norm)
+        tf.summary.histogram("PA_embeddings_norm", pa_norm)
+    return loss
+
+
+def siamese_loss(bio_logits, pad_logits, bio_labels, pad_labels, alpha=0.1):
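+    """EPSC loss for Siamese training.
+
+    ``a`` is 1 when the left and right identities match, while ``b`` and
+    ``c`` mark bona-fide samples. Combines a contrastive loss on identities
+    with an embedding-norm loss for PAD as ``(1 - alpha) * bio + alpha *
+    pad``. ``pad_logits`` is unused; the PAD term is computed from the
+    norms of the biometric embeddings.
+    """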
+    # prepare a, b, c
+    with tf.name_scope("epsc_labels"):
+        a = tf.to_float(
+            tf.math.equal(bio_labels["left"], bio_labels["right"]), name="a"
+        )
+        b = tf.to_float(tf.math.equal(pad_labels["left"], True), name="b")
+        c = tf.to_float(tf.math.equal(pad_labels["right"], True), name="c")
+        tf.summary.scalar("Mean_a", tf.reduce_mean(a))
+        tf.summary.scalar("Mean_b", tf.reduce_mean(b))
+        tf.summary.scalar("Mean_c", tf.reduce_mean(c))
+
+    prelogits_left = bio_logits["left"]
+    prelogits_right = bio_logits["right"]
+
+    bio_loss = contrastive_loss(prelogits_left, prelogits_right, labels=1 - a)
+
+    # the alpha weighting is applied once in the combination below;
+    # multiplying here as well would scale the PAD term by alpha squared
+    pad_loss = embedding_norm_loss(prelogits_left, prelogits_right, b, c)
+
+    with tf.name_scope("epsc_loss"):
+        epsc_loss = (1 - alpha) * bio_loss + alpha * pad_loss
+        tf.add_to_collection(tf.GraphKeys.LOSSES, epsc_loss)
+
+    tf.summary.scalar("epsc_loss", epsc_loss)
+
+    return epsc_loss
+
+
+def py_eer(negatives, positives):
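+    """Compute the equal error rate with ``bob.measure.eer`` inside a
+    ``tf.py_func``. Returns 0 if either score set is empty."""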
+    def _eer(neg, pos):
+        if neg.size == 0 or pos.size == 0:
+            return numpy.array(0.0, dtype="float64")
+        return bob.measure.eer(neg, pos)
+
+    negatives = tf.reshape(tf.cast(negatives, "float64"), [-1])
+    positives = tf.reshape(tf.cast(positives, "float64"), [-1])
+
+    eer = tf.py_func(_eer, [negatives, positives], tf.float64, name="py_eer")
+
+    return tf.cast(eer, "float32")
+
+
+def epsc_metric(
+    bio_embeddings,
+    pad_probabilities,
+    bio_labels,
+    pad_labels,
+    batch_size,
+    pad_threshold=numpy.exp(-15),
+):
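+    """Evaluation metrics for EPSC: biometric accuracy from embedding-based
+    predictions and PAD accuracy at the fixed ``pad_threshold``."""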
+    # math.exp(-2.0) = 0.1353352832366127
+    # math.exp(-15.0) = 3.059023205018258e-07
+    with tf.name_scope("epsc_metrics"):
+        bio_predictions_op = predict_using_tensors(
+            bio_embeddings, bio_labels, num=batch_size
+        )
+
+        # An alternative (left unused) is to derive the threshold from the
+        # batch itself, either via the EER (see ``py_eer`` above) or as the
+        # midpoint between the lowest bona-fide and the highest attack
+        # probability; the metrics below use the fixed ``pad_threshold``
+        # argument instead.
+
+        tp = tf.metrics.true_positives_at_thresholds(
+            pad_labels, pad_probabilities, [pad_threshold]
+        )
+        fp = tf.metrics.false_positives_at_thresholds(
+            pad_labels, pad_probabilities, [pad_threshold]
+        )
+        tn = tf.metrics.true_negatives_at_thresholds(
+            pad_labels, pad_probabilities, [pad_threshold]
+        )
+        fn = tf.metrics.false_negatives_at_thresholds(
+            pad_labels, pad_probabilities, [pad_threshold]
+        )
+        pad_accuracy = (tp[0] + tn[0]) / (tp[0] + tn[0] + fp[0] + fn[0])
+        pad_accuracy = tf.reduce_mean(pad_accuracy)
+        pad_update_ops = tf.group([x[1] for x in (tp, tn, fp, fn)])
+
+        eval_metric_ops = {
+            "bio_accuracy": tf.metrics.accuracy(
+                labels=bio_labels, predictions=bio_predictions_op
+            ),
+            "pad_accuracy": (pad_accuracy, pad_update_ops),
+        }
+    return eval_metric_ops
-- 
GitLab