from . import check_features, get_trainable_variables
from .Logits import moving_average_scaffold
from bob.learn.tensorflow.network.utils import append_logits
from tensorflow.python.estimator import estimator
import tensorflow as tf
import logging

logger = logging.getLogger(__name__)


class Regressor(estimator.Estimator):
    """An estimator for regression problems"""

    def __init__(
        self,
        architecture,
        optimizer=tf.train.AdamOptimizer(),
        loss_op=tf.losses.mean_squared_error,
        label_dimension=1,
        config=None,
        model_dir=None,
        apply_moving_averages=True,
        add_regularization_losses=True,
        extra_checkpoint=None,
        add_histograms=None,
        optimize_loss=tf.contrib.layers.optimize_loss,
        optimize_loss_learning_rate=None,
        architecture_has_logits=False,
    ):
        self.architecture = architecture
        self.label_dimension = label_dimension
        self.loss_op = loss_op
        self.add_regularization_losses = add_regularization_losses
        self.apply_moving_averages = apply_moving_averages

        if self.apply_moving_averages and isinstance(optimizer, tf.train.Optimizer):
            logger.info(
                "Encapsulating the optimizer with the MovingAverageOptimizer"
            )
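            # MovingAverageOptimizer keeps exponential moving averages of
            # the trainable variables; its swapping saver (see
            # moving_average_scaffold below) writes those averages into
            # checkpoints in place of the raw weights.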
            optimizer = tf.contrib.opt.MovingAverageOptimizer(optimizer)

        self.optimizer = optimizer
        self.optimize_loss = optimize_loss
        self.optimize_loss_learning_rate = optimize_loss_learning_rate

        def _model_fn(features, labels, mode, config):
            check_features(features)
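            # check_features validates that the input_fn yields a features
            # dict with a "data" entry (the network input) and a "key" entry
            # (a per-sample identifier carried through to the predictions).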
            data = features["data"]
            key = features["key"]

            # Check whether there are variables/scopes that we may want to
            # shut down (exclude from training)
            trainable_variables = get_trainable_variables(extra_checkpoint, mode=mode)
            prelogits, end_points = self.architecture(
                data, mode=mode, trainable_variables=trainable_variables
            )
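
            # Unless the architecture already ends in logits, append a
            # final dense "Logits/" layer with `label_dimension` outputs;
            # this "Logits/" scope is the one the extra_checkpoint warning
            # below refers to.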
            if architecture_has_logits:
                logits, prelogits = prelogits, end_points["prelogits"]
            else:
                logits = append_logits(
                    prelogits, label_dimension, trainable_variables=trainable_variables
                )

            predictions = {"predictions": logits, "key": key}

            if mode == tf.estimator.ModeKeys.PREDICT:
                return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

            # In PREDICT mode the logits rank must be 2, but in EVAL and
            # TRAIN the predictions should be rank 1 for the loss function!
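            # (tf.squeeze drops all size-1 dimensions, so a (batch, 1)
            # output becomes a rank-1 (batch,) tensor.)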
            predictions["predictions"] = tf.squeeze(logits)

            predictions_op = predictions["predictions"]

            # Calculate root mean squared error
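            # tf.metrics returns a (value, update_op) pair: the pair feeds
            # eval_metric_ops, while the update op (rmse[1]) is what the
            # training summary further down logs.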
            rmse = tf.metrics.root_mean_squared_error(labels, predictions_op)
            metrics = {"rmse": rmse}

            if mode == tf.estimator.ModeKeys.EVAL:
                self.loss = self._get_loss(predictions=predictions_op, labels=labels)
                return tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    loss=self.loss,
                    train_op=None,
                    eval_metric_ops=metrics,
                )

            # restore the model from an extra_checkpoint
            if extra_checkpoint is not None:
                if "Logits/" not in extra_checkpoint["scopes"]:
                    logger.warning(
                        '"Logits/" (which is automatically added by this '
                        "Regressor class) is not in the scopes of "
                        "extra_checkpoint. Did you mean to restore the "
                        "Logits variables as well?"
                    )
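
                # assignment_map maps variable scopes in the checkpoint to
                # scopes in the current graph, so the listed variables start
                # from the extra_checkpoint values instead of their fresh
                # initializers.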
                tf.train.init_from_checkpoint(
                    ckpt_dir_or_file=extra_checkpoint["checkpoint_path"],
                    assignment_map=extra_checkpoint["scopes"],
                )

            # Calculate Loss
            self.loss = self._get_loss(predictions=predictions_op, labels=labels)

            # Compute the moving average of all individual losses
            # and the total loss.
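            # (Decay 0.9; these averages only feed the "*_averaged"
            # summaries below and do not affect the gradients.)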
            loss_averages = tf.train.ExponentialMovingAverage(0.9, name="avg")
            loss_averages_op = loss_averages.apply(
                tf.get_collection(tf.GraphKeys.LOSSES)
            )

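            # tf.group ties the loss-average update to every optimization
            # step, so the averaged summaries stay current during training.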
            train_op = tf.group(
                self.optimize_loss(
                    loss=self.loss,
                    global_step=tf.train.get_or_create_global_step(),
                    optimizer=self.optimizer,
                    learning_rate=self.optimize_loss_learning_rate,
                ),
                loss_averages_op,
            )

            # Get the moving average saver only after the optimizer has been
            # applied, so that its shadow variables already exist
            if self.apply_moving_averages:
                self.saver, self.scaffold = moving_average_scaffold(
                    self.optimizer.optimizer
                    if hasattr(self.optimizer, "optimizer")
                    else self.optimizer,
                    config,
                )
            else:
                self.saver, self.scaffold = None, None

            # Log rmse and loss
            with tf.name_scope("train_metrics"):
                tf.summary.scalar("rmse", rmse[1])
                for l in tf.get_collection(tf.GraphKeys.LOSSES):
                    tf.summary.scalar(l.op.name + "_averaged", loss_averages.average(l))

            # Add histogram summaries, if requested
            if add_histograms == "all":
                for v in tf.global_variables():  # tf.all_variables is deprecated
                    tf.summary.histogram(v.name, v)
            elif add_histograms == "train":
                for v in tf.trainable_variables():
                    tf.summary.histogram(v.name, v)

            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                loss=self.loss,
                train_op=train_op,
                eval_metric_ops=metrics,
                scaffold=self.scaffold,
            )

        super(Regressor, self).__init__(
            model_fn=_model_fn, model_dir=model_dir, config=config
        )

    def _get_loss(self, predictions, labels):
        main_loss = self.loss_op(predictions=predictions, labels=labels)
        if not self.add_regularization_losses:
            return main_loss
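
        # Regularization terms may live in a different dtype than the main
        # loss (e.g. float32 penalties with a float16 model), so cast them
        # before summing everything into total_loss.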
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        regularization_losses = [
            tf.cast(l, main_loss.dtype) for l in regularization_losses
        ]
        total_loss = tf.add_n([main_loss] + regularization_losses, name="total_loss")
        return total_loss
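

# A minimal usage sketch (the `architecture` and `input_fn` below are
# illustrative placeholders, not part of this module):
#
#     def architecture(data, mode, trainable_variables=None):
#         prelogits = tf.layers.dense(data, 128, name="prelogits")
#         return prelogits, {"prelogits": prelogits}
#
#     regressor = Regressor(architecture)
#     # input_fn yields ({"data": ..., "key": ...}, labels)
#     regressor.train(input_fn)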