From 27e73aec2fc4a0a5b5e3d249c0a1050011906f9a Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Fri, 7 Feb 2020 15:24:31 +0100 Subject: [PATCH] cast labels to required format --- bob/learn/tensorflow/network/PatchCNN.py | 9 +++++---- bob/learn/tensorflow/network/SimpleCNN.py | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/bob/learn/tensorflow/network/PatchCNN.py b/bob/learn/tensorflow/network/PatchCNN.py index fc15a118..98db7e52 100644 --- a/bob/learn/tensorflow/network/PatchCNN.py +++ b/bob/learn/tensorflow/network/PatchCNN.py @@ -44,22 +44,22 @@ patch = Sequential([ Activation('relu'), MaxPool2D(padding='same'), - Conv2D(100, (3, 3), padding='same', use_bias=False, input_shape=(96,96,3)), + Conv2D(100, (3, 3), padding='same', use_bias=False), BatchNormalization(scale=False), Activation('relu'), MaxPool2D(padding='same'), - Conv2D(150, (3, 3), padding='same', use_bias=False, input_shape=(96,96,3)), + Conv2D(150, (3, 3), padding='same', use_bias=False), BatchNormalization(scale=False), Activation('relu'), MaxPool2D(pool_size=3, strides=2, padding='same'), - Conv2D(200, (3, 3), padding='same', use_bias=False, input_shape=(96,96,3)), + Conv2D(200, (3, 3), padding='same', use_bias=False), BatchNormalization(scale=False), Activation('relu'), MaxPool2D(padding='same'), - Conv2D(250, (3, 3), padding='same', use_bias=False, input_shape=(96,96,3)), + Conv2D(250, (3, 3), padding='same', use_bias=False), BatchNormalization(scale=False), Activation('relu'), MaxPool2D(padding='same'), @@ -388,6 +388,7 @@ def model_fn(features, labels, mode, params=None, config=None): return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) # Calculate Loss (for both TRAIN and EVAL modes) + labels = tf.cast(labels, dtype="int32") loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels) # Add the regularization terms to the loss if regularization_rate: diff --git a/bob/learn/tensorflow/network/SimpleCNN.py b/bob/learn/tensorflow/network/SimpleCNN.py index e27d4378..1a3b2a08 100644 --- a/bob/learn/tensorflow/network/SimpleCNN.py +++ b/bob/learn/tensorflow/network/SimpleCNN.py @@ -401,6 +401,9 @@ def model_fn(features, labels, mode, params=None, config=None): if mode == tf.estimator.ModeKeys.PREDICT: return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) + # convert labels to the expected int32 format + labels = tf.cast(labels, dtype="int32") + accuracy = tf.metrics.accuracy( labels=labels, predictions=predictions["classes"]) metrics = {'accuracy': accuracy} -- GitLab