From 27e73aec2fc4a0a5b5e3d249c0a1050011906f9a Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 7 Feb 2020 15:24:31 +0100
Subject: [PATCH] cast labels to required format

---
 bob/learn/tensorflow/network/PatchCNN.py  | 9 +++++----
 bob/learn/tensorflow/network/SimpleCNN.py | 3 +++
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/bob/learn/tensorflow/network/PatchCNN.py b/bob/learn/tensorflow/network/PatchCNN.py
index fc15a118..98db7e52 100644
--- a/bob/learn/tensorflow/network/PatchCNN.py
+++ b/bob/learn/tensorflow/network/PatchCNN.py
@@ -44,22 +44,22 @@ patch = Sequential([
     Activation('relu'),
     MaxPool2D(padding='same'),
 
-    Conv2D(100, (3, 3), padding='same', use_bias=False, input_shape=(96,96,3)),
+    Conv2D(100, (3, 3), padding='same', use_bias=False),
     BatchNormalization(scale=False),
     Activation('relu'),
     MaxPool2D(padding='same'),
 
-    Conv2D(150, (3, 3), padding='same', use_bias=False, input_shape=(96,96,3)),
+    Conv2D(150, (3, 3), padding='same', use_bias=False),
     BatchNormalization(scale=False),
     Activation('relu'),
     MaxPool2D(pool_size=3, strides=2, padding='same'),
 
-    Conv2D(200, (3, 3), padding='same', use_bias=False, input_shape=(96,96,3)),
+    Conv2D(200, (3, 3), padding='same', use_bias=False),
     BatchNormalization(scale=False),
     Activation('relu'),
     MaxPool2D(padding='same'),
 
-    Conv2D(250, (3, 3), padding='same', use_bias=False, input_shape=(96,96,3)),
+    Conv2D(250, (3, 3), padding='same', use_bias=False),
     BatchNormalization(scale=False),
     Activation('relu'),
     MaxPool2D(padding='same'),
@@ -388,6 +388,7 @@ def model_fn(features, labels, mode, params=None, config=None):
         return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
 
     # Calculate Loss (for both TRAIN and EVAL modes)
+    labels = tf.cast(labels, dtype="int32")
     loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
     # Add the regularization terms to the loss
     if regularization_rate:
diff --git a/bob/learn/tensorflow/network/SimpleCNN.py b/bob/learn/tensorflow/network/SimpleCNN.py
index e27d4378..1a3b2a08 100644
--- a/bob/learn/tensorflow/network/SimpleCNN.py
+++ b/bob/learn/tensorflow/network/SimpleCNN.py
@@ -401,6 +401,9 @@ def model_fn(features, labels, mode, params=None, config=None):
     if mode == tf.estimator.ModeKeys.PREDICT:
         return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
 
+    # convert labels to the expected int32 format
+    labels = tf.cast(labels, dtype="int32")
+
     accuracy = tf.metrics.accuracy(
         labels=labels, predictions=predictions["classes"])
     metrics = {'accuracy': accuracy}
-- 
GitLab