diff --git a/bob/learn/em/test/test_em.py b/bob/learn/em/test/test_em.py
index 9acd5ce0c1ee3edd749cd360f345f069c03e883f..0b524f21a3b22c761de9549ac1114aab666845da 100644
--- a/bob/learn/em/test/test_em.py
+++ b/bob/learn/em/test/test_em.py
@@ -55,6 +55,12 @@ def test_gmm_ML_1():
   ar = bob.io.base.load(datafile("faithful.torch3_f64.hdf5", __name__, path="../data/"))
   gmm = loadGMM()
 
+  # test rng handling
+  ml_gmmtrainer = ML_GMMTrainer(True, True, True)
+  rng = bob.core.random.mt19937(12345)
+  bob.learn.em.train(ml_gmmtrainer, gmm, ar, convergence_threshold=0.001, rng=rng)
+
+  gmm = loadGMM()
   ml_gmmtrainer = ML_GMMTrainer(True, True, True)
   #ml_gmmtrainer.train(gmm, ar)
   bob.learn.em.train(ml_gmmtrainer, gmm, ar, convergence_threshold=0.001)
@@ -114,6 +120,13 @@ def test_gmm_MAP_1():
 
   ar = bob.io.base.load(datafile('faithful.torch3_f64.hdf5', __name__, path="../data/"))
 
+  # test with rng
+  rng = bob.core.random.mt19937(12345)
+  gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
+  gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
+  map_gmmtrainer = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, prior_gmm=gmmprior, relevance_factor=4.)
+  bob.learn.em.train(map_gmmtrainer, gmm, ar, rng=rng)
+
   gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
   gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
 
@@ -253,9 +266,9 @@ def test_custom_trainer():
 
   for i in range(0, 2):
     assert (ar[i+1] == machine.means[i, :]).all()
-    
-    
-    
+
+
+
 def test_EMPCA():
 
   # Tests our Probabilistic PCA trainer for linear machines for a simple
@@ -294,5 +307,5 @@ def test_EMPCA():
   T.e_step(m, ar)
   T.m_step(m, ar)
   llh2 = T.compute_likelihood(m)
-  assert abs(exp_llh2 - llh2) < 2e-4    
-    
+  assert abs(exp_llh2 - llh2) < 2e-4
+
diff --git a/bob/learn/em/train.py b/bob/learn/em/train.py
index beb93373d7aeff03594ee81dd37fca6cf505f1b2..41114cca50bc21f43085a65dee3595f7332d2dac 100644
--- a/bob/learn/em/train.py
+++ b/bob/learn/em/train.py
@@ -45,7 +45,9 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
 
     # Initialization
     if initialize:
-        if rng is not None:
+        if rng is not None and \
+           (not isinstance(trainer, (bob.learn.em.ML_GMMTrainer,
+                                     bob.learn.em.MAP_GMMTrainer))):
             trainer.initialize(machine, data, rng)
         else:
             trainer.initialize(machine, data)