Patch for bob/learn/em/script/train.py.

Summary of the hunks below:
  * epilog example now passes grid options through --jman-options
  * --samples help text clarified (same samples, same order on every load)
  * show_default=True added to several options; --convergence-threshold gets
    a default of 4e-5; new --initialization-stride option (default 1)
  * train(): the "Please reduce array number!" assertion is removed, n_jobs is
    derived from the per-worker share, and only every
    initialization_stride-th sample is read to initialize the machine
  * load_statistics(): zeroeth-order statistics are flattened to 1-D
  * e_step(): returns early when the worker received no samples
  * read_samples(): vstack_features is called with same_size=False

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
breaks follow the hunk headers exactly; leading indentation of context lines
is inferred from standard 4-space Python layout -- confirm against the target
file before applying.
NOTE(review): two reviewer fixes were made to added lines (the misspelled
local "initilization_samples" is now "initialization_samples", and the
e_step() notice goes through logger.warning instead of print), so the
post-image blob id on the index line no longer describes the result.

diff --git a/bob/learn/em/script/train.py b/bob/learn/em/script/train.py
index 37f2233043d325e6b4cb93e852e39153ac044c0b..040a76ec4cfbe9285deafe0fa74830a229a70c53 100644
--- a/bob/learn/em/script/train.py
+++ b/bob/learn/em/script/train.py
@@ -32,7 +32,7 @@ SLEEP = 5
     cls=ConfigCommand,
     epilog="""\b
 Examples:
-  $ bob em train -vvv config.py -o /tmp/gmm -- --array 64 -q q1d -m 4G ...
+  $ bob em train -vvv config.py -o /tmp/gmm -- --array 64 --jman-options '-q q1d -i -m 4G' ...
 
 Note: samples must be sorted!
 """,
@@ -49,8 +49,9 @@ Note: samples must be sorted!
     required=True,
     cls=ResourceOption,
     help="A list of samples to be loaded with reader. The samples must be stable. "
-    "The script will be called several times in separate "
-    "processes. Each time the samples should be the same! It's best to sort them!",
+    "The script will be called several times in separate processes. Each time the "
+    "config file is loaded the samples should have the same order and must be exactly "
+    "the same! It's best to sort them!",
 )
 @click.option(
     "--output-dir",
@@ -79,23 +80,40 @@ Note: samples must be sorted!
     type=click.INT,
     default=50,
     cls=ResourceOption,
+    show_default=True,
     help="The maximum number of iterations to train a machine.",
 )
 @click.option(
     "--convergence-threshold",
     type=click.FLOAT,
+    default=4e-5,
+    show_default=True,
     cls=ResourceOption,
     help="The convergence threshold to train a machine. If None, the training "
     "procedure will stop with the iterations criteria.",
 )
+@click.option(
+    "--initialization-stride",
+    type=click.INT,
+    default=1,
+    show_default=True,
+    cls=ResourceOption,
+    help="The stride to use for selecting a subset of samples to initialize the "
+    "machine. Must be 1 or greater.",
+)
 @click.option(
     "--jman-options",
     default=" ",
+    show_default=True,
     cls=ResourceOption,
     help="Additional options to be given to jman",
 )
 @click.option(
-    "--jman", default="jman", cls=ResourceOption, help="Path to the jman script."
+    "--jman",
+    default="jman",
+    show_default=True,
+    cls=ResourceOption,
+    help="Path to the jman script.",
 )
 @click.option(
     "--step",
@@ -114,6 +132,7 @@ def train(
     machine,
     max_iterations,
     convergence_threshold,
+    initialization_stride,
     jman_options,
     jman,
     step,
@@ -145,15 +164,20 @@ def train(
         )
         raise click.Abort
 
-    # sanity check
-    assert len(samples) // array > machine.shape[0], "Please reduce array number!"
     n_samples = len(samples)
-    n_jobs = array
+    # some array jobs may not get any samples
+    # for example if n_samples is 241 and array is 64,
+    # each worker gets 4 samples and that means only 61 workers would get samples to
+    # work with
+    n_jobs = int(np.ceil(n_samples / np.ceil(n_samples / array)))
 
     # initialize
     if trainer_type in ("KMeansTrainer", "ML_GMMTrainer"):
-        logger.info("Loading %d samples to initialize the machine", len(samples))
-        data = read_samples(reader, samples)
+        initialization_samples = samples[::initialization_stride]
+        logger.info(
+            "Loading %d samples to initialize the machine", len(initialization_samples)
+        )
+        data = read_samples(reader, initialization_samples)
         logger.info("Initializing the trainer (and maybe machine)")
         trainer.initialize(machine, data)
 
@@ -349,7 +373,8 @@ def load_statistics(trainer, machine, path):
 
     with HDF5File(path, "r") as f:
         if trainer_type == "KMeansTrainer":
-            trainer.zeroeth_order_statistics = f["zeroeth_order_statistics"]
+            zeros = f["zeroeth_order_statistics"]
+            trainer.zeroeth_order_statistics = np.array(zeros).reshape((-1,))
             trainer.first_order_statistics = f["first_order_statistics"]
             trainer.average_min_distance = f["average_min_distance"]
 
@@ -366,6 +391,9 @@ def load_statistics(trainer, machine, path):
 
 
 def e_step(samples, reader, output_dir, trainer, machine):
+    if len(samples) == 0:
+        logger.warning("This worker did not get any samples.")
+        return
     logger.info("Loading %d samples", len(samples))
     data = read_samples(reader, samples)
     logger.info("Loaded all samples")
@@ -394,7 +422,7 @@ def read_samples(reader, samples):
     # read one sample to see if data is numpy arrays
     data = reader(samples[0])
     if isinstance(data, np.ndarray):
-        samples = vstack_features(reader, samples, same_size=True)
+        samples = vstack_features(reader, samples, same_size=False)
     else:
         samples = [reader(s) for s in samples]
     return samples