From 6e804b8a871be6da89856ba8043efeb06b5b5ef0 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 23 Aug 2019 11:38:56 +0200
Subject: [PATCH] Improvements to the em train script

---
 bob/learn/em/script/train.py | 50 ++++++++++++++++++++++++++++--------
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/bob/learn/em/script/train.py b/bob/learn/em/script/train.py
index 37f2233..040a76e 100644
--- a/bob/learn/em/script/train.py
+++ b/bob/learn/em/script/train.py
@@ -32,7 +32,7 @@ SLEEP = 5
     cls=ConfigCommand,
     epilog="""\b
 Examples:
-  $ bob em train -vvv config.py -o /tmp/gmm -- --array 64 -q q1d -m 4G ...
+  $ bob em train -vvv config.py -o /tmp/gmm -- --array 64 --jman-options '-q q1d -i -m 4G' ...
 
 Note: samples must be sorted!
 """,
@@ -49,8 +49,9 @@ Note: samples must be sorted!
     required=True,
     cls=ResourceOption,
     help="A list of samples to be loaded with reader. The samples must be stable. "
-    "The script will be called several times in separate "
-    "processes. Each time the samples should be the same! It's best to sort them!",
+    "The script will be called several times in separate processes. Each time the "
+    "config file is loaded the samples should have the same order and must be exactly "
+    "the same! It's best to sort them!",
 )
 @click.option(
     "--output-dir",
@@ -79,23 +80,40 @@ Note: samples must be sorted!
     type=click.INT,
     default=50,
     cls=ResourceOption,
+    show_default=True,
     help="The maximum number of iterations to train a machine.",
 )
 @click.option(
     "--convergence-threshold",
     type=click.FLOAT,
+    default=4e-5,
+    show_default=True,
     cls=ResourceOption,
     help="The convergence threshold to train a machine. If None, the training "
     "procedure will stop with the iterations criteria.",
 )
+@click.option(
+    "--initialization-stride",
+    type=click.INT,
+    default=1,
+    show_default=True,
+    cls=ResourceOption,
+    help="The stride to use for selecting a subset of samples to initialize the "
+    "machine. Must be 1 or greater.",
+)
 @click.option(
     "--jman-options",
     default=" ",
+    show_default=True,
     cls=ResourceOption,
     help="Additional options to be given to jman",
 )
 @click.option(
-    "--jman", default="jman", cls=ResourceOption, help="Path to the jman script."
+    "--jman",
+    default="jman",
+    show_default=True,
+    cls=ResourceOption,
+    help="Path to the jman script.",
 )
 @click.option(
     "--step",
@@ -114,6 +132,7 @@ def train(
     machine,
     max_iterations,
     convergence_threshold,
+    initialization_stride,
     jman_options,
     jman,
     step,
@@ -145,15 +164,20 @@ def train(
         )
         raise click.Abort
 
-    # sanity check
-    assert len(samples) // array > machine.shape[0], "Please reduce array number!"
     n_samples = len(samples)
-    n_jobs = array
+    # some array jobs may not get any samples
+    # for example if n_samples is 241 and array is 64,
+    # each worker gets 4 samples and that means only 61 workers would get samples to
+    # work with
+    n_jobs = int(np.ceil(n_samples / np.ceil(n_samples / array)))
 
     # initialize
     if trainer_type in ("KMeansTrainer", "ML_GMMTrainer"):
-        logger.info("Loading %d samples to initialize the machine", len(samples))
-        data = read_samples(reader, samples)
+        initilization_samples = samples[::initialization_stride]
+        logger.info(
+            "Loading %d samples to initialize the machine", len(initilization_samples)
+        )
+        data = read_samples(reader, initilization_samples)
 
         logger.info("Initializing the trainer (and maybe machine)")
         trainer.initialize(machine, data)
@@ -349,7 +373,8 @@ def load_statistics(trainer, machine, path):
     with HDF5File(path, "r") as f:
 
         if trainer_type == "KMeansTrainer":
-            trainer.zeroeth_order_statistics = f["zeroeth_order_statistics"]
+            zeros = f["zeroeth_order_statistics"]
+            trainer.zeroeth_order_statistics = np.array(zeros).reshape((-1,))
             trainer.first_order_statistics = f["first_order_statistics"]
             trainer.average_min_distance = f["average_min_distance"]
 
@@ -366,6 +391,9 @@ def load_statistics(trainer, machine, path):
 
 
 def e_step(samples, reader, output_dir, trainer, machine):
+    if len(samples) == 0:
+        print("This worker did not get any samples.")
+        return
     logger.info("Loading %d samples", len(samples))
     data = read_samples(reader, samples)
     logger.info("Loaded all samples")
@@ -394,7 +422,7 @@ def read_samples(reader, samples):
     # read one sample to see if data is numpy arrays
     data = reader(samples[0])
     if isinstance(data, np.ndarray):
-        samples = vstack_features(reader, samples, same_size=True)
+        samples = vstack_features(reader, samples, same_size=False)
     else:
         samples = [reader(s) for s in samples]
     return samples
-- 
GitLab