diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index 791c317f229d39e06ebe2a9e97bf732592412228..6f512eb5717ed3dba05c5c4b12579ad6a0faccb8 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -18,10 +18,10 @@ logger = logging.getLogger("bob.learn")
 class SiameseTrainer(Trainer):
     """
     Trainer for siamese networks:
-     
+
     Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to
     face verification." 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05). Vol. 1. IEEE, 2005.
-    
+
 
     **Parameters**
 
@@ -30,10 +30,10 @@ class SiameseTrainer(Trainer):
 
     iterations:
       Maximum number of iterations
-      
+
     snapshot:
       Will take a snapshot of the network at every `n` iterations
-      
+
     validation_snapshot:
       Test with validation each `n` iterations
 
@@ -127,8 +127,7 @@ class SiameseTrainer(Trainer):
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
 
-        # TODO: find an elegant way to provide this as a parameter of the trainer
-        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        self.global_step = tf.contrib.framework.get_or_create_global_step()
 
         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables())
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 4f26b1d7572acf203bedbf25db4e919f0bd01015..e3a5c6d7d6bfe1d6946d2b8b70dbadcae3ccb61a 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -25,7 +25,7 @@ logger = logging.getLogger("bob.learn")
 class Trainer(object):
     """
     One graph trainer.
-    
+
     Use this trainer when your CNN is composed by one graph
 
     **Parameters**
@@ -35,10 +35,10 @@ class Trainer(object):
 
     iterations:
       Maximum number of iterations
-      
+
     snapshot:
       Will take a snapshot of the network at every `n` iterations
-      
+
     validation_snapshot:
       Test with validation each `n` iterations
 
@@ -118,15 +118,15 @@ class Trainer(object):
 
         """
         Prepare all the tensorflow variables before training.
-        
+
         **Parameters**
-        
+
             graph: Input graph for training
-            
+
             optimizer: Solver
-            
+
             loss: Loss function
-            
+
             learning_rate: Learning rate
         """
 
@@ -139,8 +139,7 @@ class Trainer(object):
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
 
-        # TODO: find an elegant way to provide this as a parameter of the trainer
-        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        self.global_step = tf.contrib.framework.get_or_create_global_step()
 
         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables())
@@ -202,8 +201,8 @@ class Trainer(object):
         Given a data shuffler prepared the dictionary to be injected in the graph
 
         ** Parameters **
-            
-            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base` 
+
+            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`
 
         """
         [data, labels] = data_shuffler.get_batch()
diff --git a/bob/learn/tensorflow/trainers/TripletTrainer.py b/bob/learn/tensorflow/trainers/TripletTrainer.py
index 651baa11fd7cc3d2248849d99185f3bf5c3f2e7c..e15e90731023dded1b3841632dc6712d2c9e038d 100644
--- a/bob/learn/tensorflow/trainers/TripletTrainer.py
+++ b/bob/learn/tensorflow/trainers/TripletTrainer.py
@@ -19,8 +19,8 @@ logger = logging.getLogger("bob.learn")
 class TripletTrainer(Trainer):
     """
     Trainer for Triple networks:
-    
-    Schroff, Florian, Dmitry Kalenichenko, and James Philbin. 
+
+    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
     "Facenet: A unified embedding for face recognition and clustering." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
 
     **Parameters**
@@ -30,10 +30,10 @@ class TripletTrainer(Trainer):
 
     iterations:
       Maximum number of iterations
-      
+
     snapshot:
       Will take a snapshot of the network at every `n` iterations
-      
+
     validation_snapshot:
       Test with validation each `n` iterations
 
@@ -143,8 +143,7 @@ class TripletTrainer(Trainer):
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
 
-        # TODO: find an elegant way to provide this as a parameter of the trainer
-        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        self.global_step = tf.contrib.framework.get_or_create_global_step()
 
         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables())
diff --git a/bob/learn/tensorflow/trainers/learning_rate.py b/bob/learn/tensorflow/trainers/learning_rate.py
index 9a206f75c1b0adaec02ec7f1b7b37c5fd265dd85..a8b0c46725259112a1882691934216256dafb5e1 100644
--- a/bob/learn/tensorflow/trainers/learning_rate.py
+++ b/bob/learn/tensorflow/trainers/learning_rate.py
@@ -18,8 +18,8 @@ def exponential_decay(base_learning_rate=0.05,
     staircase: Boolean. It True decay the learning rate at discrete intervals
     """
 
-    global_step = tf.Variable(0, trainable=False)
-    return tf.train.exponential_decay(base_learning_rate=base_learning_rate,
+    global_step = tf.contrib.framework.get_or_create_global_step()
+    return tf.train.exponential_decay(learning_rate=base_learning_rate,
                                       global_step=global_step,
                                       decay_steps=decay_steps,
                                       decay_rate=weight_decay,